Skip to content

Commit

Permalink
Parse inputs (GalacticDynamics#86)
Browse files Browse the repository at this point in the history
* Add function convert_inputs
* adjust tests

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 24, 2024
1 parent 8e08e62 commit 26ca62b
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/galax/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""galax: Galactic Dynamix in Jax."""


from . import base, builtin, composite, core, param
from . import base, builtin, composite, core, param, utils
from .base import *
from .builtin import *
from .composite import *
from .core import *
from .param import *
from .utils import *

__all__: list[str] = []
__all__ += base.__all__
__all__ += core.__all__
__all__ += composite.__all__
__all__ += param.__all__
__all__ += builtin.__all__
__all__ += utils.__all__
50 changes: 36 additions & 14 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from dataclasses import KW_ONLY, fields, replace
from typing import TYPE_CHECKING, Any

import astropy.units as u
import equinox as eqx
import jax.experimental.array_api as xp
import jax.numpy as jnp
from astropy.constants import G as _G # pylint: disable=no-name-in-module
from astropy.coordinates import BaseRepresentation
from astropy.units import Quantity
from jax import grad, hessian, jacfwd
from jaxtyping import Array, Float

Expand All @@ -31,6 +32,8 @@
from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims
from galax.utils.dataclasses import ModuleMeta

from .utils import convert_inputs_to_arrays

if TYPE_CHECKING:
from galax.dynamics._orbit import Orbit

Expand All @@ -39,17 +42,17 @@


class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta, strict=True): # type: ignore[misc]
"""Potential Class."""
"""Abstract Potential Class."""

_: KW_ONLY
units: eqx.AbstractVar[UnitSystem]

###########################################################################
# Abstract methods that must be implemented by subclasses

@abc.abstractmethod
# @partial_jit()
# @vectorize_method(signature="(3),()->()")
@abc.abstractmethod
def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar:
"""Compute the potential energy at the given position(s).
Expand Down Expand Up @@ -80,7 +83,7 @@ def _init_units(self) -> None:
# Other fields, check their metadata
elif "dimensions" in f.metadata:
value = getattr(self, f.name)
if isinstance(value, u.Quantity):
if isinstance(value, Quantity):
value = value.to_value(
self.units[f.metadata.get("dimensions")],
equivalencies=f.metadata.get("equivalencies", None),
Expand All @@ -94,7 +97,10 @@ def _init_units(self) -> None:
# Potential energy

def potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
self,
q: BatchVec3 | Quantity | BaseRepresentation,
/,
t: BatchableFloatOrIntScalarLike | Quantity,
) -> BatchFloatScalar:
"""Compute the potential energy at the given position(s).
Expand All @@ -110,7 +116,8 @@ def potential_energy(
E : Array[float, *batch]
The potential energy per unit mass or value of the potential.
"""
return self._potential_energy(q, xp.asarray(t))
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return self._potential_energy(q, t)

@partial_jit()
def __call__(
Expand Down Expand Up @@ -145,7 +152,12 @@ def _gradient(self, q: Vec3, /, t: FloatOrIntScalar) -> Vec3:
"""See ``gradient``."""
return grad(self._potential_energy)(q, t)

def gradient(self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike) -> BatchVec3:
def gradient(
self,
q: BatchVec3 | Quantity | BaseRepresentation,
/,
t: BatchableFloatOrIntScalarLike,
) -> BatchVec3:
"""Compute the gradient of the potential at the given position(s).
Parameters
Expand All @@ -162,7 +174,8 @@ def gradient(self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike) -> BatchVe
grad : Array[float, (*batch, 3)]
The gradient of the potential.
"""
return self._gradient(q, xp.asarray(t)) # vectorize doesn't allow kwargs
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return self._gradient(q, t) # vectorize doesn't allow kwargs

# ---------------------------------------
# Density
Expand Down Expand Up @@ -194,7 +207,8 @@ def density(
rho : Array[float, *batch]
The potential energy or value of the potential.
"""
return self._density(q, xp.asarray(t))
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return self._density(q, t)

# ---------------------------------------
# Hessian
Expand All @@ -206,7 +220,10 @@ def _hessian(self, q: Vec3, /, t: FloatOrIntScalar) -> Matrix33:
return hessian(self._potential_energy)(q, t)

def hessian(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
self,
q: BatchVec3 | Quantity | BaseRepresentation,
/,
t: BatchableFloatOrIntScalarLike,
) -> BatchMatrix33:
"""Compute the Hessian of the potential at the given position(s).
Expand All @@ -224,13 +241,17 @@ def hessian(
Array[float, (*batch, 3, 3)]
The Hessian matrix of second derivatives of the potential.
"""
return self._hessian(q, xp.asarray(t))
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return self._hessian(q, t)

###########################################################################
# Convenience methods

def acceleration(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
self,
q: BatchVec3 | Quantity | BaseRepresentation,
/,
t: BatchableFloatOrIntScalarLike,
) -> BatchVec3:
"""Compute the acceleration due to the potential at the given position(s).
Expand All @@ -247,7 +268,8 @@ def acceleration(
The acceleration. Will have the same shape as the input
position array, ``q``.
"""
return -self._gradient(q, xp.asarray(t))
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return -self._gradient(q, t)

@partial_jit()
def tidal_tensor(
Expand Down Expand Up @@ -297,7 +319,7 @@ def _integrator_F(
def integrate_orbit(
self,
qp0: BatchVec6,
t: Float[Array, "time"],
t: Float[Array, "time"] | Quantity,
*,
integrator: Integrator | None = None,
) -> "Orbit":
Expand Down
102 changes: 101 additions & 1 deletion src/galax/potential/_potential/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""galax: Galactic Dynamix in Jax."""

__all__: list[str] = []

from functools import singledispatch
from typing import Any
from typing import Any, TypeVar

import jax.numpy as xp
from astropy.coordinates import BaseRepresentation, BaseRepresentationOrDifferential
from astropy.units import Quantity
from jax import Array

from galax.units import UnitSystem, dimensionless, galactic, solarsystem

Expand Down Expand Up @@ -41,3 +47,97 @@ def _from_named(value: str, /) -> UnitSystem:

msg = f"cannot convert {value} to a UnitSystem"
raise NotImplementedError(msg)


# =============================================================================


def convert_inputs_to_arrays(
*args: Any, units: UnitSystem, **kwargs: Any
) -> tuple[Array, ...]:
"""Parse input arguments.
Parameters
----------
*args : Any, positional-only
Input arguments to parse to arrays.
units : UnitSystem, keyword-only
Unit system.
**kwargs : Any
Additional keyword arguments.
Returns
-------
tuple[Array, ...]
Parsed input arguments.
"""
return tuple(convert_input_to_array(arg, units=units, **kwargs) for arg in args)


# --------------------------------------------------------------

Value = TypeVar("Value", int, float, Array)


@singledispatch
def convert_input_to_array(value: Any, /, *, units: UnitSystem, **kwargs: Any) -> Any:
"""Parse input arguments.
This function uses :func:`~functools.singledispatch` to dispatch on the type
of the input argument.
Parameters
----------
value : Any, positional-only
Input value.
units : UnitSystem, keyword-only
Unit system.
**kwargs : Any
Additional keyword arguments.
Returns
-------
Any
Parsed input value.
"""
msg = f"cannot convert {value} using units {units}"
raise NotImplementedError(msg)


@convert_input_to_array.register(int)
@convert_input_to_array.register(float)
@convert_input_to_array.register(Array)
def _convert_from_arraylike(
value: Value, /, *, units: UnitSystem, **kwargs: Any
) -> Array:
return xp.asarray(value)


@convert_input_to_array.register(Quantity)
def _convert_from_quantity(
value: Quantity, /, *, units: UnitSystem, **kwargs: Any
) -> Array:
return xp.asarray(value.decompose(units).value)


@convert_input_to_array.register(BaseRepresentationOrDifferential)
def _convert_from_baserep(
value: BaseRepresentationOrDifferential, /, *, units: UnitSystem, **kwargs: Any
) -> Array:
return xp.stack(
[getattr(value, attr).decompose(units).value for attr in value.components]
)


@convert_input_to_array.register(BaseRepresentation)
def _convert_from_representation(
value: BaseRepresentation, /, *, units: UnitSystem, **kwargs: Any
) -> Array:
if "s" in value.differentials and not kwargs.get("no_differentials", False):
return xp.stack(
(
_convert_from_baserep(value, units=units),
_convert_from_baserep(value.differentials["s"], units=units),
)
)
return _convert_from_baserep(value, units=units)
5 changes: 3 additions & 2 deletions tests/unit/potential/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import galax.dynamics as gd
import galax.potential as gp
from galax.typing import BatchableFloatOrIntScalarLike, BatchFloatScalar, BatchVec3
from galax.units import UnitSystem, dimensionless
from galax.utils import partial_jit, vectorize_method


Expand All @@ -19,7 +20,7 @@ class TestAbstractPotentialBase:
@pytest.fixture(scope="class")
def pot_cls(self) -> type[gp.AbstractPotentialBase]:
class TestPotential(gp.AbstractPotentialBase):
units: float = 2
units: UnitSystem = eqx.field(default=dimensionless, static=True)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __post_init__(self):
Expand Down Expand Up @@ -81,7 +82,7 @@ def test_init(self):

# Test that the concrete class can be instantiated
class TestPotential(gp.AbstractPotentialBase):
units: float = 2
units: UnitSystem = eqx.field(default=dimensionless, static=True)

def _potential_energy(self, q, t):
return xp.sum(q, axis=-1)
Expand Down

0 comments on commit 26ca62b

Please sign in to comment.