Skip to content

Commit

Permalink
refactor: abstractorbit methods (GalacticDynamics#221)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 14, 2024
1 parent e52af23 commit aa1c572
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion src/galax/dynamics/_dynamics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
__all__ = ["AbstractOrbit"]

from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Any, overload

import equinox as eqx
import jax
import jax.numpy as jnp

from coordinax import Abstract3DVector, Abstract3DVectorDifferential
Expand All @@ -19,7 +21,8 @@
_q_converter,
getitem_vec1time_index,
)
from galax.typing import QVec1, QVecTime
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatQScalar, QVec1, QVecTime
from galax.utils._shape import batched_shape, vector_batched_shape

if TYPE_CHECKING:
Expand All @@ -44,6 +47,9 @@ class AbstractOrbit(AbstractPhaseSpacePosition):
t: QVecTime | QVec1 = eqx.field(converter=Quantity["time"].constructor)
"""Array of times corresponding to the positions."""

potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct. Can be Vec0.
Expand Down Expand Up @@ -82,3 +88,49 @@ def __getitem__(self, index: Any) -> "Self | PhaseSpacePosition":
subindex = getitem_vec1time_index(index, self.t)
# Apply slice
return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex])

# ==========================================================================
# Dynamical quantities

@partial(jax.jit)
def potential_energy(
self, potential: AbstractPotentialBase | None = None, /
) -> BatchFloatQScalar:
r"""Return the specific potential energy.
.. math::
E_\Phi = \Phi(\boldsymbol{q})
Parameters
----------
potential : `galax.potential.AbstractPotentialBase`
The potential object to compute the energy from.
Returns
-------
E : Array[float, (*batch,)]
The specific potential energy.
"""
if potential is None:
return self.potential.potential_energy(self.q, t=self.t)
return potential.potential_energy(self.q, t=self.t)

@partial(jax.jit)
def energy(
self, potential: "AbstractPotentialBase | None" = None, /
) -> BatchFloatQScalar:
r"""Return the specific total energy.
.. math::
E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
E_\Phi = \Phi(\boldsymbol{q})
E = E_K + E_\Phi
Returns
-------
E : Array[float, (*batch,)]
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)

0 comments on commit aa1c572

Please sign in to comment.