Skip to content

Commit

Permalink
Add MilkyWayPotential (GalacticDynamics#91)
Browse files Browse the repository at this point in the history
* separate out a AbstractCompositePotential
* Add a MilkyWayPotential
* consolidate test_acceleration
* test compositepotential
* tests for MW potential
* smoke test
* finalize MWPotential

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 25, 2024
1 parent ac6b6a9 commit a229d79
Show file tree
Hide file tree
Showing 10 changed files with 733 additions and 40 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ xfail_strict = true
filterwarnings = [
"error",
"ignore:ast\\.Str is deprecated:DeprecationWarning",
"ignore:numpy\\.ndarray size changed:RuntimeWarning",
]
log_cli_level = "INFO"
testpaths = [
Expand All @@ -112,8 +113,8 @@ port.exclude_lines = [
[tool.mypy]
files = ["src"]
python_version = "3.11"
warn_unused_configs = true
strict = true
warn_unused_configs = true
show_error_codes = true
warn_unreachable = true
disallow_untyped_defs = true
Expand Down
4 changes: 3 additions & 1 deletion src/galax/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""galax: Galactic Dynamix in Jax."""


from . import base, builtin, composite, core, param, utils
from . import base, builtin, composite, core, param, special, utils
from .base import *
from .builtin import *
from .composite import *
from .core import *
from .param import *
from .special import *
from .utils import *

__all__: list[str] = []
Expand All @@ -15,4 +16,5 @@
__all__ += composite.__all__
__all__ += param.__all__
__all__ += builtin.__all__
__all__ += special.__all__
__all__ += utils.__all__
72 changes: 40 additions & 32 deletions src/galax/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import equinox as eqx
import jax.experimental.array_api as xp
from typing_extensions import override

from galax.typing import (
BatchableFloatOrIntScalarLike,
Expand All @@ -25,36 +24,10 @@
V = TypeVar("V")


@final
class CompositePotential(ImmutableDict[AbstractPotentialBase], AbstractPotentialBase):
"""Composite Potential."""

_data: dict[str, AbstractPotentialBase]
_: KW_ONLY
units: UnitSystem = eqx.field(init=False, static=True, converter=converter_to_usys)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __init__(
self,
potentials: dict[str, AbstractPotentialBase]
| tuple[tuple[str, AbstractPotentialBase], ...] = (),
/,
**kwargs: AbstractPotentialBase,
) -> None:
kwunits = kwargs.pop("units", None)
super().__init__(potentials, **kwargs)

# __post_init__ stuff:
# Check that all potentials have the same unit system
units = kwunits if kwunits is not None else first(self.values()).units
if not all(p.units == units for p in self.values()):
msg = "all potentials must have the same unit system"
raise ValueError(msg)
object.__setattr__(self, "units", units)

# Apply the unit system to any parameters.
self._init_units()

# Note: cannot have `strict=True` because of inheriting from ImmutableDict.
class AbstractCompositePotential(
ImmutableDict[AbstractPotentialBase], AbstractPotentialBase, strict=False
):
# === Potential ===

@partial_jit()
Expand All @@ -66,7 +39,6 @@ def _potential_energy(
###########################################################################
# Composite potentials

@override
def __or__(self, other: Any) -> "CompositePotential":
if not isinstance(other, AbstractPotentialBase):
return NotImplemented
Expand Down Expand Up @@ -95,3 +67,39 @@ def __ror__(self, other: Any) -> "CompositePotential":

def __add__(self, other: AbstractPotentialBase) -> "CompositePotential":
return self | other


###########################################################################


@final
class CompositePotential(AbstractCompositePotential):
"""Composite Potential."""

_data: dict[str, AbstractPotentialBase]
_: KW_ONLY
units: UnitSystem = eqx.field(init=False, static=True, converter=converter_to_usys)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __init__(
self,
potentials: dict[str, AbstractPotentialBase]
| tuple[tuple[str, AbstractPotentialBase], ...] = (),
/,
*,
units: Any = None,
**kwargs: AbstractPotentialBase,
) -> None:
super().__init__(potentials, **kwargs)

# __post_init__ stuff:
# Check that all potentials have the same unit system
units_ = units if units is not None else first(self.values()).units
usys = converter_to_usys(units_)
if not all(p.units == usys for p in self.values()):
msg = "all potentials must have the same unit system"
raise ValueError(msg)
object.__setattr__(self, "units", usys)

# Apply the unit system to any parameters.
self._init_units()
102 changes: 102 additions & 0 deletions src/galax/potential/_potential/special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["MilkyWayPotential"]


from dataclasses import KW_ONLY
from typing import Any, final

import astropy.units as u
import equinox as eqx
from astropy.units import Quantity

from galax.units import UnitSystem, dimensionless, galactic

from .base import AbstractPotentialBase
from .builtin import HernquistPotential, MiyamotoNagaiPotential, NFWPotential
from .composite import AbstractCompositePotential
from .utils import converter_to_usys

_default_disk = {"m": 6.8e10 * u.Msun, "a": 3.0 * u.kpc, "b": 0.28 * u.kpc}
_default_halo = {"m": 5.4e11 * u.Msun, "r_s": 15.62 * u.kpc}
_default_bulge = {"m": 5e9 * u.Msun, "c": 1.0 * u.kpc}
_default_nucleus = {"m": 1.71e9 * u.Msun, "c": 0.07 * u.kpc}


def _munge(value: dict[str, Quantity], units: UnitSystem) -> Any:
if units == dimensionless:
return {k: v.value for k, v in value.items()}
return value


@final
class MilkyWayPotential(AbstractCompositePotential):
"""Milky Way mass model.
A simple mass-model for the Milky Way consisting of a spherical nucleus and
bulge, a Miyamoto-Nagai disk, and a spherical NFW dark matter halo.
The disk model is taken from `Bovy (2015)
<https://ui.adsabs.harvard.edu/#abs/2015ApJS..216...29B/abstract>`_ - if you
use this potential, please also cite that work.
Default parameters are fixed by fitting to a compilation of recent mass
measurements of the Milky Way, from 10 pc to ~150 kpc.
Parameters
----------
units : `~galax.units.UnitSystem` (optional)
Set of non-reducable units that specify (at minimum) the
length, mass, time, and angle units.
disk : dict (optional)
Parameters to be passed to the :class:`~galax.potential.MiyamotoNagaiPotential`.
bulge : dict (optional)
Parameters to be passed to the :class:`~galax.potential.HernquistPotential`.
halo : dict (optional)
Parameters to be passed to the :class:`~galax.potential.NFWPotential`.
nucleus : dict (optional)
Parameters to be passed to the :class:`~galax.potential.HernquistPotential`.
Note: in subclassing, order of arguments must match order of potential
components added at bottom of init.
"""

_data: dict[str, AbstractPotentialBase] = eqx.field(init=False)
_: KW_ONLY
units: UnitSystem = eqx.field(init=True, static=True, converter=converter_to_usys)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __init__(
self,
*,
units: Any = galactic,
disk: dict[str, Any] | None = None,
halo: dict[str, Any] | None = None,
bulge: dict[str, Any] | None = None,
nucleus: dict[str, Any] | None = None,
) -> None:
units_ = converter_to_usys(units) if units is not None else galactic
super().__init__(
disk=MiyamotoNagaiPotential(
units=units_, **_munge(_default_disk, units_) | (disk or {})
),
halo=NFWPotential(
units=units_, **_munge(_default_halo, units_) | (halo or {})
),
bulge=HernquistPotential(
units=units_, **_munge(_default_bulge, units_) | (bulge or {})
),
nucleus=HernquistPotential(
units=units_, **_munge(_default_nucleus, units_) | (nucleus or {})
),
)

# __post_init__ stuff:
# Check that all potentials have the same unit system
if not all(p.units == units_ for p in self.values()):
msg = "all potentials must have the same unit system"
raise ValueError(msg)
object.__setattr__(self, "units", units_)

# Apply the unit system to any parameters.
self._init_units()
1 change: 1 addition & 0 deletions tests/smoke/potential/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ def test_all():
+ _potential.composite.__all__
+ _potential.core.__all__
+ _potential.param.__all__
+ _potential.special.__all__
)
58 changes: 58 additions & 0 deletions tests/unit/potential/builtin/test_mwpotential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any

import jax.numpy as xp
import pytest

from galax.potential import MilkyWayPotential
from galax.units import galactic

from ..test_core import TestAbstractPotential


class TestMilkyWayPotentialDefault(TestAbstractPotential):
"""Test the Milky Way potential with default parameters."""

@pytest.fixture(scope="class")
def pot_cls(self) -> type[MilkyWayPotential]:
return MilkyWayPotential

@pytest.fixture(scope="class")
def fields_(self, field_units) -> dict[str, Any]:
return {"units": field_units}

# ==========================================================================

def test_init_units_from_args(self, pot_cls, fields_unitless):
"""Test unit system from None."""
# strip the units from the fields otherwise the test will fail
# because the units are not equal and we just want to check that
# when the units aren't specified, the default is dimensionless
# and a numeric value works.
fields_unitless.pop("units")
pot = pot_cls(**fields_unitless, units=None)
assert pot.units == galactic

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
assert xp.isclose(pot.potential_energy(x, t=0), xp.array(-0.19386052))

def test_gradient(self, pot, x):
assert xp.allclose(
pot.gradient(x, t=0), xp.array([0.00256403, 0.00512806, 0.01115272])
)

def test_density(self, pot, x):
assert xp.isclose(pot.density(x, t=0), 33_365_858.46361218)

def test_hessian(self, pot, x):
assert xp.allclose(
pot.hessian(x, t=0),
xp.array(
[
[0.00231054, -0.00050698, -0.00101273],
[-0.00050698, 0.00155006, -0.00202546],
[-0.00101273, -0.00202546, -0.00197444],
]
),
)
1 change: 0 additions & 1 deletion tests/unit/potential/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def test_hessian(self, pot, x):

def test_acceleration(self, pot, x):
"""Test the `AbstractPotentialBase.acceleration` method."""
assert array_equal(pot.acceleration(x, t=0), xp.asarray([-1.0, -1, -1]))
assert array_equal(pot.acceleration(x, t=0), -pot.gradient(x, t=0))

# =========================================================================
Expand Down
Loading

0 comments on commit a229d79

Please sign in to comment.