Skip to content

Commit

Permalink
feat: vogelsberger’08 NFW (GalacticDynamics#282)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Apr 30, 2024
1 parent b36d077 commit a4054c8
Show file tree
Hide file tree
Showing 12 changed files with 533 additions and 260 deletions.
2 changes: 2 additions & 0 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ __all__ = [
"PlummerPotential",
"PowerLawCutoffPotential",
"TriaxialHernquistPotential",
"Vogelsberger08TriaxialNFWPotential",
# special
"BovyMWPotential2014",
"MilkyWayPotential",
Expand All @@ -55,6 +56,7 @@ from ._potential.builtin import (
PlummerPotential,
PowerLawCutoffPotential,
TriaxialHernquistPotential,
Vogelsberger08TriaxialNFWPotential,
)
from ._potential.composite import AbstractCompositePotential, CompositePotential
from ._potential.core import AbstractPotential
Expand Down
9 changes: 5 additions & 4 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from galax.dynamics._dynamics.orbit import Orbit


BatchRealQScalar: TypeAlias = Shaped[gt.RealQScalar, "*batch"]
QMatrix33: TypeAlias = Float[Quantity, "3 3"]
BatchMatrix33: TypeAlias = Shaped[Float[Array, "3 3"], "*batch"]
BatchQMatrix33: TypeAlias = Shaped[QMatrix33, "*batch"]
Expand All @@ -48,7 +47,7 @@
Abstract3DVector | gt.LengthBroadBatchVec3 | Shaped[Array, "*#batch 3"]
)
TimeOptions: TypeAlias = (
BatchRealQScalar
gt.BatchRealQScalar
| gt.FloatQScalar
| gt.IntQScalar
| gt.BatchableRealScalarLike
Expand Down Expand Up @@ -1016,7 +1015,7 @@ def laplacian(

@partial(jax.jit)
def _density(
self, q: gt.BatchQVec3, /, t: BatchRealQScalar | gt.RealQScalar
self, q: gt.BatchQVec3, /, t: gt.BatchRealQScalar | gt.RealQScalar
) -> gt.BatchFloatQScalar:
"""See ``density``."""
# Note: trace(jacobian(gradient)) is faster than trace(hessian(energy))
Expand Down Expand Up @@ -1842,7 +1841,9 @@ def acceleration(
# Tidal tensor

@partial(jax.jit)
def tidal_tensor(self, q: gt.BatchQVec3, /, t: BatchRealQScalar) -> BatchMatrix33:
def tidal_tensor(
self, q: gt.BatchQVec3, /, t: gt.BatchRealQScalar
) -> BatchMatrix33:
"""Compute the tidal tensor.
See https://en.wikipedia.org/wiki/Tidal_tensor
Expand Down
10 changes: 10 additions & 0 deletions src/galax/potential/_potential/builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""``galax`` Potentials."""
# ruff:noqa: F401

from . import builtin, nfw
from .builtin import *
from .nfw import *

__all__: list[str] = []
__all__ += builtin.__all__
__all__ += nfw.__all__
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
"IsochronePotential",
"KeplerPotential",
"KuzminPotential",
"LeeSutoTriaxialNFWPotential",
"LogarithmicPotential",
"MiyamotoNagaiPotential",
"NFWPotential",
"NullPotential",
"PlummerPotential",
"PowerLawCutoffPotential",
Expand All @@ -26,7 +24,6 @@
from quax import quaxify

import quaxed.array_api as xp
import quaxed.lax as qlax
import quaxed.scipy.special as qsp
from unxt import AbstractUnitSystem, Quantity, unitsystem
from unxt.unitsystems import galactic
Expand Down Expand Up @@ -281,161 +278,6 @@ def _potential_energy(
# -------------------------------------------------------------------


@final
class NFWPotential(AbstractPotential):
"""NFW Potential."""

m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
"""Mass parameter. This is NOT the total mass."""

r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
"""Scale radius of the potential."""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy( # TODO: inputs w/ units
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
v_h2 = -self.constants["G"] * self.m(t) / self.r_s(t)
r = xp.linalg.vector_norm(q, axis=-1)
m = r / self.r_s(t)
return v_h2 * xp.log(1.0 + m) / m


_log2 = xp.log(xp.asarray(2.0))


@final
class LeeSutoTriaxialNFWPotential(AbstractPotential):
"""Approximate triaxial (in the density) NFW potential.
Approximation of a Triaxial NFW Potential with the flattening in the
density, not the potential. See Lee & Suto (2003) for details.
.. warning::
This potential is only physical for `a1 >= a2 >= a3`.
Examples
--------
>>> from unxt import Quantity
>>> import galax.potential as gp
>>> pot = gp.LeeSutoTriaxialNFWPotential(
... m=Quantity(1e11, "Msun"), r_s=Quantity(15, "kpc"),
... a1=1, a2=0.9, a3=0.8, units="galactic")
>>> q = Quantity([1, 0, 0], "kpc")
>>> t = Quantity(0, "Gyr")
>>> pot.potential_energy(q, t).decompose(pot.units)
Quantity['specific energy'](Array(-0.14620419, dtype=float64), unit='kpc2 / Myr2')
>>> q = Quantity([0, 1, 0], "kpc")
>>> pot.potential_energy(q, t).decompose(pot.units)
Quantity['specific energy'](Array(-0.14593972, dtype=float64), unit='kpc2 / Myr2')
>>> q = Quantity([0, 0, 1], "kpc")
>>> pot.potential_energy(q, t).decompose(pot.units)
Quantity['specific energy'](Array(-0.14570309, dtype=float64), unit='kpc2 / Myr2')
"""

m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
r"""Scall mass.
This is the mass corresponding to the circular velocity at the scale radius.
:math:`v_c = \sqrt{G M / r_s}`
"""

r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
"""Scale radius."""

a1: AbstractParameter = ParameterField(
dimensions="dimensionless",
default=Quantity(1.0, ""), # type: ignore[assignment]
)
"""Major axis."""

a2: AbstractParameter = ParameterField(
dimensions="dimensionless",
default=Quantity(1.0, ""), # type: ignore[assignment]
)
"""Intermediate axis."""

a3: AbstractParameter = ParameterField(
dimensions="dimensionless",
default=Quantity(1.0, ""), # type: ignore[assignment]
)
"""Minor axis."""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

def __check_init__(self) -> None:
t = Quantity(0.0, "Myr")
_ = eqx.error_if(
t,
(self.a1(t) < self.a2(t)) or (self.a2(t) < self.a3(t)),
"a1 >= a2 >= a3 is required",
)

@partial(jax.jit)
def _potential_energy( # TODO: inputs w/ units
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
# https://github.com/adrn/gala/blob/2067009de41518a71c674d0252bc74a7b2d78a36/gala/potential/potential/builtin/builtin_potentials.c#L1472
# Evaluate the parameters
r_s = self.r_s(t)
v_c2 = self.constants["G"] * self.m(t) / r_s
a1, a2, a3 = self.a1(t), self.a2(t), self.a3(t)

# 1- eccentricities
e_b2 = 1 - xp.square(a2 / a1)
e_c2 = 1 - xp.square(a3 / a1)

# The potential at the origin
phi0 = v_c2 / (_log2 - 0.5 + (_log2 - 0.75) * (e_b2 + e_c2))

# The potential at the given position
r = xp.linalg.vector_norm(q, axis=-1)
u = r / r_s

# The functions F1, F2, and F3 and some useful quantities
log1pu = xp.log(1 + u)
u2 = u**2
um3 = u ** (-3)
costh2 = q[..., 2] ** 2 / r**2 # z^2 / r^2
sinthsinphi2 = q[..., 1] ** 2 / r**2 # (sin(theta) * sin(phi))^2
# Note that ꜛ is safer than computing the separate pieces, as it avoids
# x=y=0, z!=0, which would result in a NaN.

F1 = -log1pu / u
F2 = -1.0 / 3 + (2 * u2 - 3 * u + 6) / (6 * u2) + (1 / u - um3) * log1pu
F3 = (u2 - 3 * u - 6) / (2 * u2 * (1 + u)) + 3 * um3 * log1pu

# Select the output, r=0 is a special case.
out: gt.BatchFloatQScalar = phi0 * qlax.select(
u == 0,
xp.ones_like(u),
(
F1
+ (e_b2 + e_c2) / 2 * F2
+ (e_b2 * sinthsinphi2 + e_c2 * costh2) / 2 * F3
),
)
return out


# -------------------------------------------------------------------


@final
class NullPotential(AbstractPotential):
"""Null potential, i.e. no potential.
Expand Down
Loading

0 comments on commit a4054c8

Please sign in to comment.