Skip to content

Commit

Permalink
feat(pot): rename parameter c to r_s (GalacticDynamics#304)
Browse files Browse the repository at this point in the history
* feat(pot): rename parameter c to r_s

Standardising the scale radii

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed May 9, 2024
1 parent a0a58a7 commit bed9c40
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(
of motion. Here we will reproduce what happens with orbit integrations.
>>> pot = gp.HernquistPotential(m_tot=Quantity(1e12, "Msun"),
... c=Quantity(5, "kpc"), units="galactic")
... r_s=Quantity(5, "kpc"), units="galactic")
>>> integrator = gd.integrate.DiffraxIntegrator()
>>> t0, t1 = Quantity(0, "Gyr"), Quantity(1, "Gyr")
Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __call__(
of motion. Here we will reproduce what happens with orbit integrations.
>>> pot = gp.HernquistPotential(m_tot=Quantity(1e12, "Msun"),
... c=Quantity(5, "kpc"), units="galactic")
... r_s=Quantity(5, "kpc"), units="galactic")
>>> integrator = gd.integrate.DiffraxIntegrator()
>>> t0, t1 = Quantity(0, "Gyr"), Quantity(1, "Gyr")
Expand Down
17 changes: 9 additions & 8 deletions src/galax/potential/_potential/builtin/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class HernquistPotential(AbstractPotential):

m_tot: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
"""Total mass of the potential."""
c: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
"""Scale radius"""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
Expand All @@ -123,7 +124,7 @@ def _potential_energy( # TODO: inputs w/ units
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
r = xp.linalg.vector_norm(q, axis=-1)
return -self.constants["G"] * self.m_tot(t) / (r + self.c(t))
return -self.constants["G"] * self.m_tot(t) / (r + self.r_s(t))


# -------------------------------------------------------------------
Expand Down Expand Up @@ -620,7 +621,7 @@ class TriaxialHernquistPotential(AbstractPotential):
:class:`~galax.potential.AbstractParameter` or an appropriate callable
or constant, like a Quantity. See
:class:`~galax.potential.ParameterField` for details.
c : :class:`~galax.potential.AbstractParameter`['length']
r_s : :class:`~galax.potential.AbstractParameter`['length']
A scale length that determines the concentration of the system. This
can be a :class:`~galax.potential.AbstractParameter` or an appropriate
callable or constant, like a Quantity. See
Expand All @@ -647,7 +648,7 @@ class TriaxialHernquistPotential(AbstractPotential):
>>> from galax.potential import TriaxialHernquistPotential
>>> pot = TriaxialHernquistPotential(m_tot=Quantity(1e12, "Msun"),
... c=Quantity(8, "kpc"), q1=1, q2=0.5,
... r_s=Quantity(8, "kpc"), q1=1, q2=0.5,
... units="galactic")
>>> q = Quantity([1, 0, 0], "kpc")
Expand All @@ -659,7 +660,7 @@ class TriaxialHernquistPotential(AbstractPotential):
m_tot: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
"""Mass of the potential."""

c: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
"""Scale a scale length that determines the concentration of the system."""

# TODO: move to a triaxial wrapper
Expand All @@ -685,8 +686,8 @@ class TriaxialHernquistPotential(AbstractPotential):
def _potential_energy( # TODO: inputs w/ units
self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, /
) -> gt.BatchFloatQScalar:
c, q1, q2 = self.c(t), self.q1(t), self.q2(t)
c = eqx.error_if(c, c.value <= 0, "c must be positive")
r_s, q1, q2 = self.r_s(t), self.q1(t), self.q2(t)
r_s = eqx.error_if(r_s, r_s.value <= 0, "r_s must be positive")

rprime = xp.sqrt(q[..., 0] ** 2 + (q[..., 1] / q1) ** 2 + (q[..., 2] / q2) ** 2)
return -self.constants["G"] * self.m_tot(t) / (rprime + c)
return -self.constants["G"] * self.m_tot(t) / (rprime + r_s)
6 changes: 3 additions & 3 deletions src/galax/potential/_potential/builtin/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class LM10Potential(AbstractCompositePotential):
)
# TODO: as an actual `HernquistPotential`, then use `replace`?
_default_bulge: ClassVar[Mapping[str, Any]] = MappingProxyType(
{"m_tot": Quantity(3.4e10, "Msun"), "c": Quantity(0.7, "kpc")}
{"m_tot": Quantity(3.4e10, "Msun"), "r_s": Quantity(0.7, "kpc")}
)
# TODO: as an actual `LMJ09LogarithmicPotential`, then use `replace`?
_default_halo: ClassVar[Mapping[str, Any]] = MappingProxyType(
Expand Down Expand Up @@ -273,11 +273,11 @@ class MilkyWayPotential(AbstractCompositePotential):
)
# TODO: as an actual `HernquistPotential`, then use `replace`?
_default_bulge: ClassVar[MappingProxyType[str, Quantity]] = MappingProxyType(
{"m_tot": Quantity(5e9, "Msun"), "c": Quantity(1.0, "kpc")}
{"m_tot": Quantity(5e9, "Msun"), "r_s": Quantity(1.0, "kpc")}
)
# TODO: as an actual `HernquistPotential`, then use `replace`?
_default_nucleus: ClassVar[MappingProxyType[str, Quantity]] = MappingProxyType(
{"m_tot": Quantity(1.71e9, "Msun"), "c": Quantity(0.07, "kpc")}
{"m_tot": Quantity(1.71e9, "Msun"), "r_s": Quantity(0.07, "kpc")}
)

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PotentialFrame(AbstractPotentialBase):
Now we define a triaxial Hernquist potential with a time-dependent mass:
>>> mfunc = gp.UserParameter(lambda t: 1e12 * (1 + t.to_units_value("Gyr") / 10), unit="Msun")
>>> pot = gp.TriaxialHernquistPotential(m_tot=mfunc, c=Quantity(1, "kpc"),
>>> pot = gp.TriaxialHernquistPotential(m_tot=mfunc, r_s=Quantity(1, "kpc"),
... q1=1, q2=0.5, units="galactic")
Let's see the triaxiality of the potential:
Expand Down Expand Up @@ -171,7 +171,7 @@ class PotentialFrame(AbstractPotentialBase):
effect of the rotation more obvious:
>>> pot2 = gp.TriaxialHernquistPotential(m_tot=Quantity(1e12, "Msun"),
... c=Quantity(1, "kpc"), q1=0.1, q2=0.1, units="galactic")
... r_s=Quantity(1, "kpc"), q1=0.1, q2=0.1, units="galactic")
>>> op7 = gc.operators.ConstantRotationZOperator(Omega_z=Quantity(90, "deg/Gyr"))
>>> framedpot7 = gp.PotentialFrame(potential=pot2, operator=op7)
Expand Down
32 changes: 28 additions & 4 deletions src/galax/potential/_potential/io/_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
m_tot=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
c=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
Isochrone potential:
Expand Down Expand Up @@ -178,7 +178,6 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote


_GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = {
gp.HernquistPotential: gpx.HernquistPotential,
gp.IsochronePotential: gpx.IsochronePotential,
gp.KeplerPotential: gpx.KeplerPotential,
gp.KuzminPotential: gpx.KuzminPotential,
Expand All @@ -188,7 +187,6 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote
}


@gala_to_galax.register(gp.HernquistPotential)
@gala_to_galax.register(gp.IsochronePotential)
@gala_to_galax.register(gp.KeplerPotential)
@gala_to_galax.register(gp.KuzminPotential)
Expand All @@ -198,7 +196,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote
def _gala_to_galax_registered(
gala: gp.PotentialBase, /
) -> gpx.AbstractPotential | gpx.PotentialFrame:
"""Convert a Gala HernquistPotential to a Galax potential."""
"""Convert a Gala potential to a Galax potential."""
if isinstance(gala.units, GalaDimensionlessUnitSystem):
msg = "Galax does not support converting dimensionless units."
raise TypeError(msg)
Expand Down Expand Up @@ -234,6 +232,32 @@ def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential:
return gpx.NullPotential(units=pot.units)


@gala_to_galax.register
def _gala_to_galax_hernquist(
gala: gp.HernquistPotential, /
) -> gpx.HernquistPotential | gpx.PotentialFrame:
r"""Convert a Gala HernquistPotential to a Galax potential.
Examples
--------
>>> import gala.potential as gp
>>> import gala.units as gu
>>> import galax.potential as gpx
>>> gpot = gp.HernquistPotential(m=1e11, c=20, units=gu.galactic)
>>> gpx.io.gala_to_galax(gpot)
HernquistPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
m_tot=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)
"""
params = gala.parameters
pot = gpx.HernquistPotential(m_tot=params["m"], r_s=params["c"], units=gala.units)
return _apply_frame(_get_frame(gala), pot)


@gala_to_galax.register
def _gala_to_galax_jaffe(
gala: gp.JaffePotential, /
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/potential/builtin/misc/test_hernquist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@

import galax.typing as gt
from ...test_core import TestAbstractPotential as AbstractPotential_Test
from ..test_common import ParameterMTotMixin, ParameterShapeCMixin
from ..test_common import ParameterMTotMixin, ParameterScaleRadiusMixin
from galax.potential import AbstractPotentialBase, HernquistPotential


class TestHernquistPotential(
AbstractPotential_Test,
# Parameters
ParameterMTotMixin,
ParameterShapeCMixin,
ParameterScaleRadiusMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[HernquistPotential]:
return HernquistPotential

@pytest.fixture(scope="class")
def fields_(self, field_m_tot, field_c, field_units) -> dict[str, Any]:
return {"m_tot": field_m_tot, "c": field_c, "units": field_units}
def fields_(self, field_m_tot, field_r_s, field_units) -> dict[str, Any]:
return {"m_tot": field_m_tot, "r_s": field_r_s, "units": field_units}

# ==========================================================================

Expand Down
8 changes: 4 additions & 4 deletions tests/unit/potential/builtin/misc/test_triaxialhernquist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ...test_core import TestAbstractPotential as AbstractPotential_Test
from ..test_common import (
ParameterMTotMixin,
ParameterShapeCMixin,
ParameterScaleRadiusMixin,
ParameterShapeQ1Mixin,
ParameterShapeQ2Mixin,
)
Expand All @@ -20,7 +20,7 @@ class TestTriaxialHernquistPotential(
AbstractPotential_Test,
# Parameters
ParameterMTotMixin,
ParameterShapeCMixin,
ParameterScaleRadiusMixin,
ParameterShapeQ1Mixin,
ParameterShapeQ2Mixin,
):
Expand All @@ -30,11 +30,11 @@ def pot_cls(self) -> type[TriaxialHernquistPotential]:

@pytest.fixture(scope="class")
def fields_(
self, field_m_tot, field_c, field_q1, field_q2, field_units
self, field_m_tot, field_r_s, field_q1, field_q2, field_units
) -> dict[str, Any]:
return {
"m_tot": field_m_tot,
"c": field_c,
"r_s": field_r_s,
"q1": field_q1,
"q2": field_q2,
"units": field_units,
Expand Down
29 changes: 23 additions & 6 deletions tests/unit/potential/io/gala_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def _galax_to_gala_composite(pot: gpx.CompositePotential, /) -> gp.CompositePote
}


@galax_to_gala.register(gpx.HernquistPotential)
@galax_to_gala.register(gpx.IsochronePotential)
@galax_to_gala.register(gpx.KeplerPotential)
@galax_to_gala.register(gpx.KuzminPotential)
Expand Down Expand Up @@ -114,6 +113,20 @@ def _galax_to_gala_bar(pot: gpx.BarPotential, /) -> gp.PotentialBase:
raise NotImplementedError # TODO: implement


@galax_to_gala.register
def _galax_to_gala_hernquist(pot: gpx.HernquistPotential, /) -> gp.HernquistPotential:
"""Convert a Galax HernquistPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m_tot", "r_s"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return gp.HernquistPotential(
m=convert(pot.m_tot(0), APYQuantity),
c=convert(pot.r_s(0), APYQuantity),
units=galax_to_gala_units(pot.units),
)


@galax_to_gala.register
def _galax_to_gala_jaffe(pot: gpx.JaffePotential, /) -> gp.JaffePotential:
"""Convert a Galax JaffePotential to a Gala potential."""
Expand Down Expand Up @@ -294,18 +307,20 @@ def rename(k: str) -> str:
def _galax_to_gala_lm10(pot: gpx.LM10Potential, /) -> gp.LM10Potential:
"""Convert a Galax LM10Potential to a Gala potential."""

def rename(k: str) -> str:
def rename(c: str, k: str) -> str:
match k:
case "m_tot":
return "m"
case "r_s":
case "r_s" if c == "halo":
return "r_h"
case "r_s" if c == "bulge":
return "c"
case _:
return k

return gp.LM10Potential(
**{
c: {rename(k): getattr(p, k)(0) for k in p.parameters}
c: {rename(c, k): getattr(p, k)(0) for k in p.parameters}
for c, p in pot.items()
}
)
Expand All @@ -315,16 +330,18 @@ def rename(k: str) -> str:
def _galax_to_gala_mwpotential(pot: gpx.MilkyWayPotential, /) -> gp.MilkyWayPotential:
"""Convert a Galax MilkyWayPotential to a Gala potential."""

def rename(k: str) -> str:
def rename(c: str, k: str) -> str:
match k:
case "m_tot":
return "m"
case "r_s" if c in ("bulge", "nucleus"):
return "c"
case _:
return k

return gp.MilkyWayPotential(
**{
c: {rename(k): getattr(p, k)(0) for k in p.parameters}
c: {rename(c, k): getattr(p, k)(0) for k in p.parameters}
for c, p in pot.items()
}
)

0 comments on commit bed9c40

Please sign in to comment.