Skip to content

Commit

Permalink
feat(pot): LM10Potential and LMJ09logarithmic (GalacticDynamics#296)
Browse files Browse the repository at this point in the history
* feat(pot): LM10Potential and LMJ09logarithmic
* refactor(pot): move around potentials and tests
* fix: kuxmin gala compat
* fix: longmuralibar compat

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed May 6, 2024
1 parent f3c44e5 commit d72457c
Show file tree
Hide file tree
Showing 48 changed files with 963 additions and 344 deletions.
2 changes: 1 addition & 1 deletion src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from typing import Self

from galax.potential._potential.base import AbstractPotentialBase
from galax.potential import AbstractPotentialBase


class ComponentShapeTuple(NamedTuple):
Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_q_converter,
getitem_vec1time_index,
)
from galax.potential._potential.base import AbstractPotentialBase
from galax.potential import AbstractPotentialBase
from galax.typing import BatchFloatQScalar, QVec1, QVecTime
from galax.utils._shape import batched_shape, vector_batched_shape

Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/integrate/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ._builtin import DiffraxIntegrator
from galax.coordinates import PhaseSpacePosition
from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.potential import AbstractPotentialBase

##############################################################################

Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/mockstream/df/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ._progenitor import ConstantMassProtenitor, ProgenitorMassCallable
from galax.dynamics._dynamics.mockstream.core import MockStream
from galax.dynamics._dynamics.orbit import Orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.potential import AbstractPotentialBase

Wif: TypeAlias = tuple[
gt.LengthVec3,
Expand Down
1 change: 0 additions & 1 deletion src/galax/dynamics/_dynamics/mockstream/df/_fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import galax.typing as gt
from ._base import AbstractStreamDF
from galax.potential import AbstractPotentialBase
from galax.potential._potential.base import AbstractPotentialBase

# ============================================================
# Constants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_default_integrator,
evaluate_orbit,
)
from galax.potential._potential.base import AbstractPotentialBase
from galax.potential import AbstractPotentialBase

Carry: TypeAlias = tuple[gt.IntScalar, gt.VecN, gt.VecN]

Expand Down
16 changes: 13 additions & 3 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ __all__ = [
"JaffePotential",
"KeplerPotential",
"KuzminPotential",
"LogarithmicPotential",
"LongMuraliBarPotential",
"MiyamotoNagaiPotential",
"NullPotential",
Expand All @@ -27,13 +26,17 @@ __all__ = [
"SatohPotential",
"StoneOstriker15Potential",
"TriaxialHernquistPotential",
# logarithmic
"LogarithmicPotential",
"LMJ09LogarithmicPotential",
# nfw
"NFWPotential",
"LeeSutoTriaxialNFWPotential",
"TriaxialNFWPotential",
"Vogelsberger08TriaxialNFWPotential",
# special
"BovyMWPotential2014",
"LM10Potential",
"MilkyWayPotential",
# frame
"PotentialFrame",
Expand Down Expand Up @@ -63,7 +66,6 @@ from ._potential.builtin.builtin import (
JaffePotential,
KeplerPotential,
KuzminPotential,
LogarithmicPotential,
LongMuraliBarPotential,
MiyamotoNagaiPotential,
NullPotential,
Expand All @@ -73,12 +75,21 @@ from ._potential.builtin.builtin import (
StoneOstriker15Potential,
TriaxialHernquistPotential,
)
from ._potential.builtin.logarithmic import (
LMJ09LogarithmicPotential,
LogarithmicPotential,
)
from ._potential.builtin.nfw import (
LeeSutoTriaxialNFWPotential,
NFWPotential,
TriaxialNFWPotential,
Vogelsberger08TriaxialNFWPotential,
)
from ._potential.builtin.special import (
BovyMWPotential2014,
LM10Potential,
MilkyWayPotential,
)
from ._potential.composite import AbstractCompositePotential, CompositePotential
from ._potential.core import AbstractPotential
from ._potential.frame import PotentialFrame
Expand All @@ -99,4 +110,3 @@ from ._potential.param import (
ParametersAttribute,
UserParameter,
)
from ._potential.special import BovyMWPotential2014, MilkyWayPotential
6 changes: 5 additions & 1 deletion src/galax/potential/_potential/builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""``galax`` Potentials."""
# ruff:noqa: F401

from . import builtin, nfw
from . import builtin, logarithmic, nfw, special
from .builtin import *
from .logarithmic import *
from .nfw import *
from .special import *

__all__: list[str] = []
__all__ += builtin.__all__
__all__ += logarithmic.__all__
__all__ += nfw.__all__
__all__ += special.__all__
94 changes: 94 additions & 0 deletions src/galax/potential/_potential/builtin/logarithmic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = [
"LogarithmicPotential",
"LMJ09LogarithmicPotential",
]

from dataclasses import KW_ONLY
from functools import partial
from typing import final

import equinox as eqx
import jax

import quaxed.array_api as xp
from unxt import AbstractUnitSystem, Quantity, unitsystem

import galax.typing as gt
from galax.potential._potential.base import default_constants
from galax.potential._potential.core import AbstractPotential
from galax.potential._potential.param import AbstractParameter, ParameterField
from galax.utils import ImmutableDict


@final
class LogarithmicPotential(AbstractPotential):
"""Logarithmic Potential."""

v_c: AbstractParameter = ParameterField(dimensions="speed") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy(
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
r2 = xp.linalg.vector_norm(q, axis=-1).to_value(self.units["length"]) ** 2
return (
0.5
* self.v_c(t) ** 2
* xp.log(self.r_s(t).to_value(self.units["length"]) ** 2 + r2)
)


@final
class LMJ09LogarithmicPotential(AbstractPotential):
"""Logarithmic Potential from LMJ09.
https://ui.adsabs.harvard.edu/abs/2009ApJ...703L..67L/abstract
"""

v_c: AbstractParameter = ParameterField(dimensions="speed") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]

q1: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment]
q2: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment]
q3: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment]

phi: AbstractParameter = ParameterField(dimensions="angle") # type: ignore[assignment]

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy(
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
# Load parameters
q1, q2, q3 = self.q1(t), self.q2(t), self.q3(t)
phi = self.phi(t)

# Rotated and scaled coordinates
sphi, cphi = xp.sin(phi), xp.cos(phi)
x = q[..., 0] * cphi + q[..., 1] * sphi
y = -q[..., 0] * sphi + q[..., 1] * cphi
r2 = (x / q1) ** 2 + (y / q2) ** 2 + (q[..., 2] / q3) ** 2

# Potential energy
return (
0.5
* self.v_c(t) ** 2
* xp.log(
self.r_s(t).to_value(self.units["length"]) ** 2
+ r2.to_value(self.units["area"])
)
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["BovyMWPotential2014", "MilkyWayPotential"]
__all__ = [
"BovyMWPotential2014",
"LM10Potential",
"MilkyWayPotential",
]


from collections.abc import Mapping
Expand All @@ -10,17 +14,15 @@

import equinox as eqx

import quaxed.array_api as xp
from unxt import Quantity
from unxt.unitsystems import AbstractUnitSystem, dimensionless, galactic, unitsystem

from .base import AbstractPotentialBase, default_constants
from .builtin.builtin import (
HernquistPotential,
MiyamotoNagaiPotential,
PowerLawCutoffPotential,
)
from .builtin.nfw import NFWPotential
from .composite import AbstractCompositePotential
from .builtin import HernquistPotential, MiyamotoNagaiPotential, PowerLawCutoffPotential
from .logarithmic import LMJ09LogarithmicPotential
from .nfw import NFWPotential
from galax.potential._potential.base import AbstractPotentialBase, default_constants
from galax.potential._potential.composite import AbstractCompositePotential
from galax.utils import ImmutableDict

T = TypeVar("T", bound=AbstractPotentialBase)
Expand Down Expand Up @@ -123,6 +125,102 @@ def __init__(
)


_sqrt2 = xp.sqrt(xp.asarray(2.0))


@final
class LM10Potential(AbstractCompositePotential):
"""Law & Majewski (2010) Milky Way mass model.
The Galactic potential used by Law and Majewski (2010) to represent the
Milky Way as a three-component sum of disk, bulge, and halo.
The disk potential is an axisymmetric
:class:`~galax.potential.MiyamotoNagaiPotential`, the bulge potential is a
spherical :class:`~galax.potential.HernquistPotential`, and the halo
potential is a triaxial :class:`~galax.potential.LMJ09LogarithmicPotential`.
Default parameters are fixed to those found in LM10 by fitting N-body
simulations to the Sagittarius stream.
Parameters
----------
units : `~galax.units.UnitSystem` (optional)
Set of non-reducable units that specify (at minimum) the length, mass,
time, and angle units.
disk : dict (optional)
Parameters to be passed to the
:class:`~galax.potential.MiyamotoNagaiPotential`.
bulge : dict (optional)
Parameters to be passed to the
:class:`~galax.potential.HernquistPotential`.
halo : dict (optional)
Parameters to be passed to the
:class:`~galax.potential.LMJ09LogarithmicPotential`.
Note: in subclassing, order of arguments must match order of potential
components added at bottom of init.
"""

_data: dict[str, AbstractPotentialBase] = eqx.field(init=False)
_: KW_ONLY
units: AbstractUnitSystem = eqx.field(
default=galactic, static=True, converter=unitsystem
)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

# TODO: as an actual `MiyamotoNagaiPotential`, then use `replace`?
_default_disk: ClassVar[Mapping[str, Any]] = MappingProxyType(
{
"m_tot": Quantity(1e11, "Msun"),
"a": Quantity(6.5, "kpc"),
"b": Quantity(0.26, "kpc"),
}
)
# TODO: as an actual `HernquistPotential`, then use `replace`?
_default_bulge: ClassVar[Mapping[str, Any]] = MappingProxyType(
{"m_tot": Quantity(3.4e10, "Msun"), "c": Quantity(0.7, "kpc")}
)
# TODO: as an actual `LMJ09LogarithmicPotential`, then use `replace`?
_default_halo: ClassVar[Mapping[str, Any]] = MappingProxyType(
{
"v_c": Quantity(_sqrt2 * 121.858, "km / s"),
"r_s": Quantity(12.0, "kpc"),
"q1": 1.38,
"q2": 1.0,
"q3": 1.36,
"phi": Quantity(97, "degree"),
}
)

def __init__(
self,
*,
disk: MiyamotoNagaiPotential | Mapping[str, Any] | None = None,
bulge: HernquistPotential | Mapping[str, Any] | None = None,
halo: LMJ09LogarithmicPotential | Mapping[str, Any] | None = None,
units: Any = galactic,
constants: Any = default_constants,
) -> None:
units_ = unitsystem(units) if units is not None else galactic

super().__init__(
disk=_parse_input_comp(
MiyamotoNagaiPotential, disk, self._default_disk, units_
),
bulge=_parse_input_comp(
HernquistPotential, bulge, self._default_bulge, units_
),
halo=_parse_input_comp(
LMJ09LogarithmicPotential, halo, self._default_halo, units_
),
units=units_,
constants=constants,
)


@final
class MilkyWayPotential(AbstractCompositePotential):
"""Milky Way mass model.
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from unxt import AbstractUnitSystem, Quantity

import galax.typing as gt
from galax.potential._potential.base import AbstractPotentialBase
from .base import AbstractPotentialBase
from galax.utils import ImmutableDict


Expand Down
Loading

0 comments on commit d72457c

Please sign in to comment.