Skip to content

Commit

Permalink
Reference Frame Transformation Operators (GalacticDynamics#135)
Browse files Browse the repository at this point in the history
* feat: frame operations
* feat: not public time-dep translation class
* fix: tests

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 3, 2024
1 parent 0841812 commit 9264b0b
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 21 deletions.
17 changes: 8 additions & 9 deletions src/galax/coordinates/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import equinox as eqx
from jaxtyping import Shaped
from plum import dispatch
from plum import convert, dispatch

from coordinax import Abstract3DVector, AbstractVector, Cartesian3DVector, FourVector
from jax_quantity import Quantity
Expand Down Expand Up @@ -113,7 +113,7 @@ def __call__(
q: Shaped[Quantity["length"], "*batch 3"],
t: Quantity["time"],
/,
) -> tuple[Abstract3DVector, Quantity["time"]]:
) -> tuple[Shaped[Quantity["length"], "*batch 3"], Quantity["time"]]:
"""Apply the operator to the coordinates.
Examples
Expand All @@ -134,10 +134,11 @@ def __call__(
>>> t = Quantity(0.0, "Gyr")
>>> op(pos, t)
(Cartesian3DVector( ... ),
(Quantity['length'](Array([2., 4., 6.], dtype=float64), unit='kpc'),
Quantity['time'](Array(0., dtype=float64, ...), unit='Gyr'))
"""
return self(Cartesian3DVector.constructor(q), t)
cart, t = self(Cartesian3DVector.constructor(q), t)
return convert(cart, Quantity), t

@dispatch
def __call__(self: "AbstractOperator", x: FourVector, /) -> FourVector:
Expand Down Expand Up @@ -173,7 +174,7 @@ def __call__(self: "AbstractOperator", x: FourVector, /) -> FourVector:
@dispatch
def __call__(
self: "AbstractOperator", q: Shaped[Quantity["length"], "*#batch 4"], /
) -> FourVector:
) -> Shaped[Quantity["length"], "*#batch 4"]:
"""Apply the operator to the coordinates.
Examples
Expand All @@ -196,11 +197,9 @@ def __call__(
>>> newpos = op(pos)
>>> newpos
FourVector( t=Quantity[PhysicalType('time')](...), q=Cartesian3DVector( ... ) )
>>> newpos.q.x
Quantity['length'](Array(2., dtype=float64), unit='kpc')
Quantity['length'](Array([0., 2., 4., 6.], dtype=float64), unit='kpc')
"""
return self(FourVector.constructor(q))
return convert(self(FourVector.constructor(q)), Quantity)

@dispatch
def __call__(
Expand Down
10 changes: 10 additions & 0 deletions src/galax/coordinates/operators/galilean.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ class GalileanSpatialTranslationOperator(AbstractGalileanOperator):
>>> op(q, t)
(Cartesian3DVector( ... ), Quantity['time'](Array(0, dtype=int64, ...), unit='Gyr'))
Translation operators can be used to translate potentials:
>>> import galax.potential as gp
>>> pot = gp.KeplerPotential(m=Quantity(1e12, "Msun"), units="galactic")
>>> translated_pot = gp.PotentialFrame(pot, op)
>>> translated_pot
PotentialFrame(
potential=KeplerPotential( ... ),
operator=OperatorSequence( ... GalileanSpatialTranslationOperator( ... ), ) )
)
"""

translation: Abstract3DVector = eqx.field(
Expand Down
40 changes: 40 additions & 0 deletions src/galax/dynamics/_dynamics/mockstream/_moving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Moving object. TEMPORARY CLASS.
This class adds time-dependent translations to a potential.
It is NOT careful about the implied changes to velocity, etc.
"""

__all__: list[str] = []


from collections.abc import Callable
from typing import Literal, final

import equinox as eqx

from galax.coordinates.operators import AbstractOperator
from galax.typing import FloatScalar, RealScalar, Vec3


@final
class TimeDependentSpatialTranslationOperator(AbstractOperator):
r"""Operator for time-dependent translation."""

translation: Callable[[FloatScalar], Vec3] = eqx.field()
"""The spatial translation."""

def __call__(self, q: Vec3, t: RealScalar) -> tuple[Vec3, RealScalar]:
"""Do."""
return (q + self.translation(t), t)

@property
def is_inertial(self) -> Literal[False]:
"""Galilean translation is an inertial frame-preserving transformation."""
return False

@property
def inverse(self) -> "TimeDependentSpatialTranslationOperator":
"""The inverse of the operator."""
return TimeDependentSpatialTranslationOperator(
translation=lambda t: -self.translation(t)
)
3 changes: 3 additions & 0 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ __all__ = [
"TriaxialHernquistPotential",
# special
"MilkyWayPotential",
# frame
"PotentialFrame",
]

from ._potential import io
Expand All @@ -45,6 +47,7 @@ from ._potential.builtin import (
)
from ._potential.composite import AbstractCompositePotential, CompositePotential
from ._potential.core import AbstractPotential
from ._potential.frame import PotentialFrame
from ._potential.param import (
AbstractParameter,
ConstantParameter,
Expand Down
1 change: 1 addition & 0 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _init_units(self) -> None:
# ---------------------------------------
# Potential energy

# TODO: inputs w/ units
# @partial(jax.jit)
# @vectorize_method(signature="(3),()->()")
@abc.abstractmethod
Expand Down
14 changes: 8 additions & 6 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BarPotential(AbstractPotential):
_: KW_ONLY
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

# TODO: inputs w/ units
@partial(jax.jit)
@vectorize_method(signature="(3),()->()")
def _potential_energy(self, q: Vec3, /, t: RealScalarLike) -> FloatScalar:
Expand Down Expand Up @@ -102,7 +103,7 @@ class HernquistPotential(AbstractPotential):
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
r = xp.linalg.vector_norm(q, axis=-1)
Expand All @@ -122,7 +123,7 @@ class IsochronePotential(AbstractPotential):
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
r = xp.linalg.vector_norm(q, axis=-1)
Expand All @@ -146,7 +147,7 @@ class KeplerPotential(AbstractPotential):
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
r = xp.linalg.vector_norm(q, axis=-1)
Expand All @@ -166,6 +167,7 @@ class MiyamotoNagaiPotential(AbstractPotential):
_: KW_ONLY
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

# TODO: inputs w/ units
@partial(jax.jit)
@vectorize_method(signature="(3),()->()")
def _potential_energy(self, q: Vec3, /, t: RealScalarLike) -> FloatScalar:
Expand All @@ -191,7 +193,7 @@ class NFWPotential(AbstractPotential):
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
v_h2 = -self._G * self.m(t) / self.r_s(t)
Expand All @@ -211,7 +213,7 @@ class NullPotential(AbstractPotential):
units: UnitSystem = eqx.field(converter=unitsystem, static=True)

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self,
q: BatchVec3,
/,
Expand Down Expand Up @@ -292,7 +294,7 @@ class TriaxialHernquistPotential(AbstractPotential):
"""The unit system to use for the potential."""

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
c, q1, q2 = self.c(t), self.q1(t), self.q2(t)
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AbstractCompositePotential(
# === Potential ===

@partial(jax.jit)
def _potential_energy(
def _potential_energy( # TODO: inputs w/ units
self, q: BatchVec3, /, t: BatchableRealScalarLike
) -> BatchFloatScalar:
return xp.sum(
Expand Down
1 change: 1 addition & 0 deletions src/galax/potential/_potential/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __post_init__(self) -> None:
###########################################################################
# Abstract methods that must be implemented by subclasses

# TODO: inputs w/ units
@abc.abstractmethod
def _potential_energy(self, q: Vec3, /, t: RealScalar) -> FloatScalar:
raise NotImplementedError
Expand Down
Loading

0 comments on commit 9264b0b

Please sign in to comment.