From d72457cf3b255c7b1b245edb950df5990f99778c Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 6 May 2024 11:20:31 -0400 Subject: [PATCH] feat(pot): LM10Potential and LMJ09logarithmic (#296) * feat(pot): LM10Potential and LMJ09logarithmic * refactor(pot): move around potentials and tests * fix: kuxmin gala compat * fix: longmuralibar compat Signed-off-by: nstarman --- src/galax/coordinates/_psp/base.py | 2 +- src/galax/dynamics/_dynamics/base.py | 2 +- .../dynamics/_dynamics/integrate/_funcs.py | 2 +- .../dynamics/_dynamics/mockstream/df/_base.py | 2 +- .../_dynamics/mockstream/df/_fardal.py | 1 - .../mockstream/mockstream_generator.py | 2 +- src/galax/potential/__init__.pyi | 16 +- .../potential/_potential/builtin/__init__.py | 6 +- .../_potential/builtin/logarithmic.py | 94 +++++++++ .../_potential/{ => builtin}/special.py | 116 ++++++++++- src/galax/potential/_potential/frame.py | 2 +- src/galax/potential/_potential/io/_gala.py | 117 ++++++++++- src/galax/potential/_potential/param/attr.py | 2 +- src/galax/potential/_potential/param/field.py | 2 +- tests/smoke/potential/test_package.py | 2 - .../logarithmic}/__init__.py | 0 .../builtin/logarithmic/test_common.py | 78 ++++++++ .../logarithmic/test_lmj09logarithmic.py | 173 +++++++++++++++++ .../builtin/logarithmic/test_logarithmic.py | 85 ++++++++ tests/unit/potential/builtin/misc/__init__.py | 0 .../potential/builtin/{ => misc}/test_bar.py | 4 +- .../builtin/{ => misc}/test_hernquist.py | 7 +- .../builtin/{ => misc}/test_isochrone.py | 4 +- .../builtin/{ => misc}/test_jaffe.py | 4 +- .../builtin/{ => misc}/test_kepler.py | 4 +- .../builtin/{ => misc}/test_kuzmin.py | 4 +- .../{ => misc}/test_leesutotriaxialnfw.py | 7 +- .../builtin/{ => misc}/test_longmuralibar.py | 6 +- .../builtin/{ => misc}/test_miyamotonagai.py | 4 +- .../potential/builtin/{ => misc}/test_null.py | 2 +- .../builtin/{ => misc}/test_plummer.py | 28 +-- .../builtin/{ => misc}/test_powerlawcutoff.py | 17 +- .../builtin/{ => misc}/test_satoh.py | 4 +- .../builtin/{ => misc}/test_stone.py | 4 +- .../{ => misc}/test_triaxialhernquist.py | 21 +- tests/unit/potential/builtin/nfw/__init__.py | 0 .../potential/builtin/{ => nfw}/test_nfw.py | 4 +- .../builtin/{ => nfw}/test_triaxialnfw.py | 4 +- .../{ => nfw}/test_vogelsberger08nfw.py | 6 +- .../potential/builtin/special/__init__.py | 0 .../special/test_bovymwpotential2014.py | 13 +- .../potential/builtin/special/test_lm10.py | 126 ++++++++++++ .../special/test_milkywaypotential.py | 2 +- tests/unit/potential/builtin/test_common.py | 23 +++ .../potential/builtin/test_logarithmichalo.py | 182 ------------------ tests/unit/potential/io/gala_helper.py | 64 +++++- tests/unit/potential/io/test_gala.py | 32 +-- tests/unit/potential/test_composite.py | 27 --- 48 files changed, 963 insertions(+), 344 deletions(-) create mode 100644 src/galax/potential/_potential/builtin/logarithmic.py rename src/galax/potential/_potential/{ => builtin}/special.py (65%) rename tests/unit/potential/{special => builtin/logarithmic}/__init__.py (100%) create mode 100644 tests/unit/potential/builtin/logarithmic/test_common.py create mode 100644 tests/unit/potential/builtin/logarithmic/test_lmj09logarithmic.py create mode 100644 tests/unit/potential/builtin/logarithmic/test_logarithmic.py create mode 100644 tests/unit/potential/builtin/misc/__init__.py rename tests/unit/potential/builtin/{ => misc}/test_bar.py (96%) rename tests/unit/potential/builtin/{ => misc}/test_hernquist.py (90%) rename tests/unit/potential/builtin/{ => misc}/test_isochrone.py (94%) rename tests/unit/potential/builtin/{ => misc}/test_jaffe.py (95%) rename tests/unit/potential/builtin/{ => misc}/test_kepler.py (95%) rename tests/unit/potential/builtin/{ => misc}/test_kuzmin.py (96%) rename tests/unit/potential/builtin/{ => misc}/test_leesutotriaxialnfw.py (96%) rename tests/unit/potential/builtin/{ => misc}/test_longmuralibar.py (96%) rename tests/unit/potential/builtin/{ => misc}/test_miyamotonagai.py (94%) rename tests/unit/potential/builtin/{ => misc}/test_null.py (98%) rename tests/unit/potential/builtin/{ => misc}/test_plummer.py (75%) rename tests/unit/potential/builtin/{ => misc}/test_powerlawcutoff.py (92%) rename tests/unit/potential/builtin/{ => misc}/test_satoh.py (94%) rename tests/unit/potential/builtin/{ => misc}/test_stone.py (96%) rename tests/unit/potential/builtin/{ => misc}/test_triaxialhernquist.py (80%) create mode 100644 tests/unit/potential/builtin/nfw/__init__.py rename tests/unit/potential/builtin/{ => nfw}/test_nfw.py (95%) rename tests/unit/potential/builtin/{ => nfw}/test_triaxialnfw.py (96%) rename tests/unit/potential/builtin/{ => nfw}/test_vogelsberger08nfw.py (96%) create mode 100644 tests/unit/potential/builtin/special/__init__.py rename tests/unit/potential/{ => builtin}/special/test_bovymwpotential2014.py (92%) create mode 100644 tests/unit/potential/builtin/special/test_lm10.py rename tests/unit/potential/{ => builtin}/special/test_milkywaypotential.py (98%) delete mode 100644 tests/unit/potential/builtin/test_logarithmichalo.py diff --git a/src/galax/coordinates/_psp/base.py b/src/galax/coordinates/_psp/base.py index 849c53da..5441d750 100644 --- a/src/galax/coordinates/_psp/base.py +++ b/src/galax/coordinates/_psp/base.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from typing import Self - from galax.potential._potential.base import AbstractPotentialBase + from galax.potential import AbstractPotentialBase class ComponentShapeTuple(NamedTuple): diff --git a/src/galax/dynamics/_dynamics/base.py b/src/galax/dynamics/_dynamics/base.py index 86ca3f42..38bae8e5 100644 --- a/src/galax/dynamics/_dynamics/base.py +++ b/src/galax/dynamics/_dynamics/base.py @@ -25,7 +25,7 @@ _q_converter, getitem_vec1time_index, ) -from galax.potential._potential.base import AbstractPotentialBase +from galax.potential import AbstractPotentialBase from galax.typing import BatchFloatQScalar, QVec1, QVecTime from galax.utils._shape import batched_shape, vector_batched_shape diff --git a/src/galax/dynamics/_dynamics/integrate/_funcs.py b/src/galax/dynamics/_dynamics/integrate/_funcs.py index 9bb61fc1..73bb4bde 100644 --- a/src/galax/dynamics/_dynamics/integrate/_funcs.py +++ b/src/galax/dynamics/_dynamics/integrate/_funcs.py @@ -18,7 +18,7 @@ from ._builtin import DiffraxIntegrator from galax.coordinates import PhaseSpacePosition from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit -from galax.potential._potential.base import AbstractPotentialBase +from galax.potential import AbstractPotentialBase ############################################################################## diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_base.py b/src/galax/dynamics/_dynamics/mockstream/df/_base.py index 21f6ff7a..bc85b70d 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_base.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_base.py @@ -19,7 +19,7 @@ from ._progenitor import ConstantMassProtenitor, ProgenitorMassCallable from galax.dynamics._dynamics.mockstream.core import MockStream from galax.dynamics._dynamics.orbit import Orbit -from galax.potential._potential.base import AbstractPotentialBase +from galax.potential import AbstractPotentialBase Wif: TypeAlias = tuple[ gt.LengthVec3, diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py b/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py index cbe4a722..bba64efa 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py @@ -18,7 +18,6 @@ import galax.typing as gt from ._base import AbstractStreamDF from galax.potential import AbstractPotentialBase -from galax.potential._potential.base import AbstractPotentialBase # ============================================================ # Constants diff --git a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py index 1c028698..9cefe412 100644 --- a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py +++ b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py @@ -25,7 +25,7 @@ _default_integrator, evaluate_orbit, ) -from galax.potential._potential.base import AbstractPotentialBase +from galax.potential import AbstractPotentialBase Carry: TypeAlias = tuple[gt.IntScalar, gt.VecN, gt.VecN] diff --git a/src/galax/potential/__init__.pyi b/src/galax/potential/__init__.pyi index 113598b5..b3dbd157 100644 --- a/src/galax/potential/__init__.pyi +++ b/src/galax/potential/__init__.pyi @@ -18,7 +18,6 @@ __all__ = [ "JaffePotential", "KeplerPotential", "KuzminPotential", - "LogarithmicPotential", "LongMuraliBarPotential", "MiyamotoNagaiPotential", "NullPotential", @@ -27,6 +26,9 @@ __all__ = [ "SatohPotential", "StoneOstriker15Potential", "TriaxialHernquistPotential", + # logarithmic + "LogarithmicPotential", + "LMJ09LogarithmicPotential", # nfw "NFWPotential", "LeeSutoTriaxialNFWPotential", @@ -34,6 +36,7 @@ __all__ = [ "Vogelsberger08TriaxialNFWPotential", # special "BovyMWPotential2014", + "LM10Potential", "MilkyWayPotential", # frame "PotentialFrame", @@ -63,7 +66,6 @@ from ._potential.builtin.builtin import ( JaffePotential, KeplerPotential, KuzminPotential, - LogarithmicPotential, LongMuraliBarPotential, MiyamotoNagaiPotential, NullPotential, @@ -73,12 +75,21 @@ from ._potential.builtin.builtin import ( StoneOstriker15Potential, TriaxialHernquistPotential, ) +from ._potential.builtin.logarithmic import ( + LMJ09LogarithmicPotential, + LogarithmicPotential, +) from ._potential.builtin.nfw import ( LeeSutoTriaxialNFWPotential, NFWPotential, TriaxialNFWPotential, Vogelsberger08TriaxialNFWPotential, ) +from ._potential.builtin.special import ( + BovyMWPotential2014, + LM10Potential, + MilkyWayPotential, +) from ._potential.composite import AbstractCompositePotential, CompositePotential from ._potential.core import AbstractPotential from ._potential.frame import PotentialFrame @@ -99,4 +110,3 @@ from ._potential.param import ( ParametersAttribute, UserParameter, ) -from ._potential.special import BovyMWPotential2014, MilkyWayPotential diff --git a/src/galax/potential/_potential/builtin/__init__.py b/src/galax/potential/_potential/builtin/__init__.py index a59197d5..cf76fe0e 100644 --- a/src/galax/potential/_potential/builtin/__init__.py +++ b/src/galax/potential/_potential/builtin/__init__.py @@ -1,10 +1,14 @@ """``galax`` Potentials.""" # ruff:noqa: F401 -from . import builtin, nfw +from . import builtin, logarithmic, nfw, special from .builtin import * +from .logarithmic import * from .nfw import * +from .special import * __all__: list[str] = [] __all__ += builtin.__all__ +__all__ += logarithmic.__all__ __all__ += nfw.__all__ +__all__ += special.__all__ diff --git a/src/galax/potential/_potential/builtin/logarithmic.py b/src/galax/potential/_potential/builtin/logarithmic.py new file mode 100644 index 00000000..49642221 --- /dev/null +++ b/src/galax/potential/_potential/builtin/logarithmic.py @@ -0,0 +1,94 @@ +"""galax: Galactic Dynamix in Jax.""" + +__all__ = [ + "LogarithmicPotential", + "LMJ09LogarithmicPotential", +] + +from dataclasses import KW_ONLY +from functools import partial +from typing import final + +import equinox as eqx +import jax + +import quaxed.array_api as xp +from unxt import AbstractUnitSystem, Quantity, unitsystem + +import galax.typing as gt +from galax.potential._potential.base import default_constants +from galax.potential._potential.core import AbstractPotential +from galax.potential._potential.param import AbstractParameter, ParameterField +from galax.utils import ImmutableDict + + +@final +class LogarithmicPotential(AbstractPotential): + """Logarithmic Potential.""" + + v_c: AbstractParameter = ParameterField(dimensions="speed") # type: ignore[assignment] + r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment] + + _: KW_ONLY + units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True) + constants: ImmutableDict[Quantity] = eqx.field( + default=default_constants, converter=ImmutableDict + ) + + @partial(jax.jit) + def _potential_energy( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.BatchFloatQScalar: + r2 = xp.linalg.vector_norm(q, axis=-1).to_value(self.units["length"]) ** 2 + return ( + 0.5 + * self.v_c(t) ** 2 + * xp.log(self.r_s(t).to_value(self.units["length"]) ** 2 + r2) + ) + + +@final +class LMJ09LogarithmicPotential(AbstractPotential): + """Logarithmic Potential from LMJ09. + + https://ui.adsabs.harvard.edu/abs/2009ApJ...703L..67L/abstract + """ + + v_c: AbstractParameter = ParameterField(dimensions="speed") # type: ignore[assignment] + r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment] + + q1: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + q2: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + q3: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + + phi: AbstractParameter = ParameterField(dimensions="angle") # type: ignore[assignment] + + _: KW_ONLY + units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True) + constants: ImmutableDict[Quantity] = eqx.field( + default=default_constants, converter=ImmutableDict + ) + + @partial(jax.jit) + def _potential_energy( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.BatchFloatQScalar: + # Load parameters + q1, q2, q3 = self.q1(t), self.q2(t), self.q3(t) + phi = self.phi(t) + + # Rotated and scaled coordinates + sphi, cphi = xp.sin(phi), xp.cos(phi) + x = q[..., 0] * cphi + q[..., 1] * sphi + y = -q[..., 0] * sphi + q[..., 1] * cphi + r2 = (x / q1) ** 2 + (y / q2) ** 2 + (q[..., 2] / q3) ** 2 + + # Potential energy + return ( + 0.5 + * self.v_c(t) ** 2 + * xp.log( + self.r_s(t).to_value(self.units["length"]) ** 2 + + r2.to_value(self.units["area"]) + ) + ) diff --git a/src/galax/potential/_potential/special.py b/src/galax/potential/_potential/builtin/special.py similarity index 65% rename from src/galax/potential/_potential/special.py rename to src/galax/potential/_potential/builtin/special.py index d4996004..c7ddda60 100644 --- a/src/galax/potential/_potential/special.py +++ b/src/galax/potential/_potential/builtin/special.py @@ -1,6 +1,10 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["BovyMWPotential2014", "MilkyWayPotential"] +__all__ = [ + "BovyMWPotential2014", + "LM10Potential", + "MilkyWayPotential", +] from collections.abc import Mapping @@ -10,17 +14,15 @@ import equinox as eqx +import quaxed.array_api as xp from unxt import Quantity from unxt.unitsystems import AbstractUnitSystem, dimensionless, galactic, unitsystem -from .base import AbstractPotentialBase, default_constants -from .builtin.builtin import ( - HernquistPotential, - MiyamotoNagaiPotential, - PowerLawCutoffPotential, -) -from .builtin.nfw import NFWPotential -from .composite import AbstractCompositePotential +from .builtin import HernquistPotential, MiyamotoNagaiPotential, PowerLawCutoffPotential +from .logarithmic import LMJ09LogarithmicPotential +from .nfw import NFWPotential +from galax.potential._potential.base import AbstractPotentialBase, default_constants +from galax.potential._potential.composite import AbstractCompositePotential from galax.utils import ImmutableDict T = TypeVar("T", bound=AbstractPotentialBase) @@ -123,6 +125,102 @@ def __init__( ) +_sqrt2 = xp.sqrt(xp.asarray(2.0)) + + +@final +class LM10Potential(AbstractCompositePotential): + """Law & Majewski (2010) Milky Way mass model. + + The Galactic potential used by Law and Majewski (2010) to represent the + Milky Way as a three-component sum of disk, bulge, and halo. + + The disk potential is an axisymmetric + :class:`~galax.potential.MiyamotoNagaiPotential`, the bulge potential is a + spherical :class:`~galax.potential.HernquistPotential`, and the halo + potential is a triaxial :class:`~galax.potential.LMJ09LogarithmicPotential`. + + Default parameters are fixed to those found in LM10 by fitting N-body + simulations to the Sagittarius stream. + + Parameters + ---------- + units : `~galax.units.UnitSystem` (optional) + Set of non-reducable units that specify (at minimum) the length, mass, + time, and angle units. + disk : dict (optional) + Parameters to be passed to the + :class:`~galax.potential.MiyamotoNagaiPotential`. + bulge : dict (optional) + Parameters to be passed to the + :class:`~galax.potential.HernquistPotential`. + halo : dict (optional) + Parameters to be passed to the + :class:`~galax.potential.LMJ09LogarithmicPotential`. + + Note: in subclassing, order of arguments must match order of potential + components added at bottom of init. + """ + + _data: dict[str, AbstractPotentialBase] = eqx.field(init=False) + _: KW_ONLY + units: AbstractUnitSystem = eqx.field( + default=galactic, static=True, converter=unitsystem + ) + constants: ImmutableDict[Quantity] = eqx.field( + default=default_constants, converter=ImmutableDict + ) + + # TODO: as an actual `MiyamotoNagaiPotential`, then use `replace`? + _default_disk: ClassVar[Mapping[str, Any]] = MappingProxyType( + { + "m_tot": Quantity(1e11, "Msun"), + "a": Quantity(6.5, "kpc"), + "b": Quantity(0.26, "kpc"), + } + ) + # TODO: as an actual `HernquistPotential`, then use `replace`? + _default_bulge: ClassVar[Mapping[str, Any]] = MappingProxyType( + {"m_tot": Quantity(3.4e10, "Msun"), "c": Quantity(0.7, "kpc")} + ) + # TODO: as an actual `LMJ09LogarithmicPotential`, then use `replace`? + _default_halo: ClassVar[Mapping[str, Any]] = MappingProxyType( + { + "v_c": Quantity(_sqrt2 * 121.858, "km / s"), + "r_s": Quantity(12.0, "kpc"), + "q1": 1.38, + "q2": 1.0, + "q3": 1.36, + "phi": Quantity(97, "degree"), + } + ) + + def __init__( + self, + *, + disk: MiyamotoNagaiPotential | Mapping[str, Any] | None = None, + bulge: HernquistPotential | Mapping[str, Any] | None = None, + halo: LMJ09LogarithmicPotential | Mapping[str, Any] | None = None, + units: Any = galactic, + constants: Any = default_constants, + ) -> None: + units_ = unitsystem(units) if units is not None else galactic + + super().__init__( + disk=_parse_input_comp( + MiyamotoNagaiPotential, disk, self._default_disk, units_ + ), + bulge=_parse_input_comp( + HernquistPotential, bulge, self._default_bulge, units_ + ), + halo=_parse_input_comp( + LMJ09LogarithmicPotential, halo, self._default_halo, units_ + ), + units=units_, + constants=constants, + ) + + @final class MilkyWayPotential(AbstractCompositePotential): """Milky Way mass model. diff --git a/src/galax/potential/_potential/frame.py b/src/galax/potential/_potential/frame.py index 043dc239..ffc7e262 100644 --- a/src/galax/potential/_potential/frame.py +++ b/src/galax/potential/_potential/frame.py @@ -12,7 +12,7 @@ from unxt import AbstractUnitSystem, Quantity import galax.typing as gt -from galax.potential._potential.base import AbstractPotentialBase +from .base import AbstractPotentialBase from galax.utils import ImmutableDict diff --git a/src/galax/potential/_potential/io/_gala.py b/src/galax/potential/_potential/io/_gala.py index 63cd6ec7..a2ce710f 100644 --- a/src/galax/potential/_potential/io/_gala.py +++ b/src/galax/potential/_potential/io/_gala.py @@ -182,7 +182,6 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote gp.IsochronePotential: gpx.IsochronePotential, gp.KeplerPotential: gpx.KeplerPotential, gp.KuzminPotential: gpx.KuzminPotential, - gp.LogarithmicPotential: gpx.LogarithmicPotential, gp.MiyamotoNagaiPotential: gpx.MiyamotoNagaiPotential, gp.PlummerPotential: gpx.PlummerPotential, gp.PowerLawCutoffPotential: gpx.PowerLawCutoffPotential, @@ -192,6 +191,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote @gala_to_galax.register(gp.HernquistPotential) @gala_to_galax.register(gp.IsochronePotential) @gala_to_galax.register(gp.KeplerPotential) +@gala_to_galax.register(gp.KuzminPotential) @gala_to_galax.register(gp.MiyamotoNagaiPotential) @gala_to_galax.register(gp.PlummerPotential) @gala_to_galax.register(gp.PowerLawCutoffPotential) @@ -260,6 +260,42 @@ def _gala_to_galax_jaffe( return _apply_frame(_get_frame(gala), pot) +@gala_to_galax.register +def _gala_to_galax_longmuralibar( + gala: gp.LongMuraliBarPotential, / +) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame: + """Convert a Gala LongMuraliBarPotential to a Galax potential. + + Examples + -------- + >>> import gala.potential as gp + >>> import gala.units as gu + >>> import galax.potential as gpx + + >>> gpot = gp.LongMuraliBarPotential(m=1e11, a=20, b=10, c=5, units=gu.galactic) + >>> gpx.io.gala_to_galax(gpot) + LongMuraliBarPotential( + units=UnitSystem(kpc, Myr, solMass, rad), + constants=ImmutableDict({'G': Quantity...}), + m_tot=ConstantParameter( ... ), + a=ConstantParameter( ... ), + b=ConstantParameter( ... ), + c=ConstantParameter( ... ), + alpha=ConstantParameter( ... ) + ) + """ + params = gala.parameters + pot = gpx.LongMuraliBarPotential( + m_tot=params["m"], + a=params["a"], + b=params["b"], + c=params["c"], + alpha=params["alpha"], + units=gala.units, + ) + return _apply_frame(_get_frame(gala), pot) + + @gala_to_galax.register def _gala_to_galax_satoh( gala: gp.SatohPotential, / @@ -318,6 +354,62 @@ def _gala_to_galax_stoneostriker15( return _apply_frame(_get_frame(gala), pot) +# ----------------------------------------------------------------------------- +# Logarithmic potentials + + +@gala_to_galax.register +def _gala_to_galax_logarithmic( + gala: gp.LogarithmicPotential, / +) -> gpx.LogarithmicPotential | gpx.LMJ09LogarithmicPotential | gpx.PotentialFrame: + """Convert a Gala LogarithmicPotential to a Galax potential. + + If the flattening or rotation 'phi' is non-zero, the potential is a + :class:`galax.potential.LMJ09LogarithmicPotential` (or + :class:`galax.potential.PotentialFrame` wrapper thereof). Otherwise, it is a + :class:`galax.potential.LogarithmicPotential` (or + :class:`galax.potential.PotentialFrame` wrapper thereof). + + Examples + -------- + >>> import gala.potential as gp + >>> import gala.units as gu + >>> import galax.potential as gpx + + >>> gpot = gp.LogarithmicPotential(v_c=220, r_h=20, units=gu.galactic) + >>> gpx.io.gala_to_galax(gpot) + LogarithmicPotential( + units=UnitSystem(kpc, Myr, solMass, rad), + constants=ImmutableDict({'G': ...}), + v_c=ConstantParameter( ... ), + r_s=ConstantParameter( ... ) + ) + """ + params = gala.parameters + + if ( + params["q1"] != 1 + or params["q2"] != 1 + or params["q3"] != 1 + or params["phi"] != 0 + ): + pot = gpx.LMJ09LogarithmicPotential( + v_c=params["v_c"], + r_s=params["r_h"], + q1=params["q1"], + q2=params["q2"], + q3=params["q3"], + phi=params["phi"], + units=gala.units, + ) + else: + pot = gpx.LogarithmicPotential( + v_c=params["v_c"], r_s=params["r_h"], units=gala.units + ) + + return _apply_frame(_get_frame(gala), pot) + + # ----------------------------------------------------------------------------- # NFW potentials @@ -427,6 +519,29 @@ def _gala_to_galax_bovymw2014( ) +@gala_to_galax.register +def _gala_to_galax_lm10(pot: gp.LM10Potential, /) -> gpx.LM10Potential: + """Convert a Gala LM10Potential to a Galax potential. + + Examples + -------- + >>> import gala.potential as gp + >>> import galax.potential as gpx + + >>> gpot = gp.LM10Potential() + >>> gpx.io.gala_to_galax(gpot) + LM10Potential({'disk': MiyamotoNagaiPotential( ... ), + 'bulge': HernquistPotential( ... ), + 'halo': LMJ09LogarithmicPotential( ... )}) + + """ + return gpx.LM10Potential( + disk=gala_to_galax(pot["disk"]), + bulge=gala_to_galax(pot["bulge"]), + halo=gala_to_galax(pot["halo"]), + ) + + @gala_to_galax.register def _gala_to_galax_mw(pot: gp.MilkyWayPotential, /) -> gpx.MilkyWayPotential: """Convert a Gala MilkyWayPotential to a Galax potential. diff --git a/src/galax/potential/_potential/param/attr.py b/src/galax/potential/_potential/param/attr.py index 53f0c985..1d77caca 100644 --- a/src/galax/potential/_potential/param/attr.py +++ b/src/galax/potential/_potential/param/attr.py @@ -9,7 +9,7 @@ from .field import ParameterField if TYPE_CHECKING: - from galax.potential._potential.base import AbstractPotentialBase + from galax.potential import AbstractPotentialBase @final diff --git a/src/galax/potential/_potential/param/field.py b/src/galax/potential/_potential/param/field.py index e4eea56f..5055e936 100644 --- a/src/galax/potential/_potential/param/field.py +++ b/src/galax/potential/_potential/param/field.py @@ -27,7 +27,7 @@ from galax.utils.dataclasses import Sentinel, dataclass_with_converter, field if TYPE_CHECKING: - from galax.potential._potential.base import AbstractPotentialBase + from galax.potential import AbstractPotentialBase def converter_parameter(value: Any) -> AbstractParameter: diff --git a/tests/smoke/potential/test_package.py b/tests/smoke/potential/test_package.py index db5d6370..6f356c6c 100644 --- a/tests/smoke/potential/test_package.py +++ b/tests/smoke/potential/test_package.py @@ -7,7 +7,6 @@ frame, funcs, param, - special, ) @@ -21,7 +20,6 @@ def test_all() -> None: *composite.__all__, *core.__all__, *param.__all__, - *special.__all__, *frame.__all__, *funcs.__all__, } diff --git a/tests/unit/potential/special/__init__.py b/tests/unit/potential/builtin/logarithmic/__init__.py similarity index 100% rename from tests/unit/potential/special/__init__.py rename to tests/unit/potential/builtin/logarithmic/__init__.py diff --git a/tests/unit/potential/builtin/logarithmic/test_common.py b/tests/unit/potential/builtin/logarithmic/test_common.py new file mode 100644 index 00000000..d3e85666 --- /dev/null +++ b/tests/unit/potential/builtin/logarithmic/test_common.py @@ -0,0 +1,78 @@ +import astropy.units as u +import pytest + +import quaxed.numpy as qnp +from unxt import Quantity +from unxt.unitsystems import galactic + +import galax.potential as gp +from ...param.test_field import ParameterFieldMixin +from galax.potential import ConstantParameter + + +class ParameterVCMixin(ParameterFieldMixin): + """Test the circular velocity parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_v_c(self) -> Quantity["speed"]: + return Quantity(220, "km/s") + + # ===================================================== + + def test_v_c_units(self, pot_cls, fields): + """Test the speed parameter.""" + fields["v_c"] = Quantity(1.0, u.Unit(220 * u.km / u.s)) + fields["units"] = galactic + pot = pot_cls(**fields) + assert isinstance(pot.v_c, ConstantParameter) + assert pot.v_c.value == Quantity(220, "km/s") + + def test_v_c_constant(self, pot_cls, fields): + """Test the speed parameter.""" + fields["v_c"] = Quantity(1.0, "km/s") + pot = pot_cls(**fields) + assert pot.v_c(t=0) == Quantity(1.0, "km/s") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_v_c_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + fields["v_c"] = lambda t: t + 2 + pot = pot_cls(**fields) + assert pot.v_c(t=0) == 2 + + +class ParameterRSMixin(ParameterFieldMixin): + """Test the scale radius parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_r_s(self) -> Quantity["length"]: + return Quantity(8, "kpc") + + # ===================================================== + + def test_r_s_units(self, pot_cls, fields): + """Test the speed parameter.""" + fields["r_s"] = Quantity(1, u.Unit(10 * u.kpc)) + fields["units"] = galactic + pot = pot_cls(**fields) + assert isinstance(pot.r_s, ConstantParameter) + assert qnp.isclose( + pot.r_s.value, Quantity(10, "kpc"), atol=Quantity(1e-15, "kpc") + ) + + def test_r_s_constant(self, pot_cls, fields): + """Test the speed parameter.""" + fields["r_s"] = Quantity(11.0, "kpc") + pot = pot_cls(**fields) + assert pot.r_s(t=0) == Quantity(11.0, "kpc") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_r_s_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + fields["r_s"] = lambda t: t + 2 + pot = pot_cls(**fields) + assert pot.r_s(t=0) == 2 diff --git a/tests/unit/potential/builtin/logarithmic/test_lmj09logarithmic.py b/tests/unit/potential/builtin/logarithmic/test_lmj09logarithmic.py new file mode 100644 index 00000000..8ad1cc7a --- /dev/null +++ b/tests/unit/potential/builtin/logarithmic/test_lmj09logarithmic.py @@ -0,0 +1,173 @@ +from typing import Any + +import astropy.units as u +import pytest + +import quaxed.numpy as qnp +from unxt import AbstractUnitSystem, Quantity + +import galax.potential as gp +import galax.typing as gt +from ...param.test_field import ParameterFieldMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( + ParameterShapeQ1Mixin, + ParameterShapeQ2Mixin, + ParameterShapeQ3Mixin, +) +from .test_common import ParameterRSMixin, ParameterVCMixin +from galax.potential import ( + AbstractPotentialBase, + ConstantParameter, + LMJ09LogarithmicPotential, +) +from galax.utils._optional_deps import HAS_GALA + + +class ParameterPhiMixin(ParameterFieldMixin): + """Test the phi parameter.""" + + pot_cls: type[gp.AbstractPotential] + + @pytest.fixture(scope="class") + def field_phi(self) -> Quantity["angle"]: + return Quantity(220, "deg") + + # ===================================================== + + def test_phi_units(self, pot_cls, fields): + """Test the speed parameter.""" + fields["phi"] = Quantity(1.0, u.Unit(220 * u.deg)) + pot = pot_cls(**fields) + assert isinstance(pot.phi, ConstantParameter) + assert pot.phi.value == Quantity(220, "deg") + + def test_phi_constant(self, pot_cls, fields): + """Test the speed parameter.""" + fields["phi"] = Quantity(1.0, "deg") + pot = pot_cls(**fields) + assert pot.phi(t=0) == Quantity(1.0, "deg") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_phi_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + fields["phi"] = lambda t: t + 2 + pot = pot_cls(**fields) + assert pot.phi(t=0) == 2 + + +class TestLMJ09LogarithmicPotential( + AbstractPotential_Test, + # Parameters + ParameterVCMixin, + ParameterRSMixin, + ParameterShapeQ1Mixin, + ParameterShapeQ2Mixin, + ParameterShapeQ3Mixin, + ParameterPhiMixin, +): + """Test the `galax.potential.LMJ09LogarithmicPotential` class.""" + + @pytest.fixture(scope="class") + def pot_cls(self) -> type[gp.LMJ09LogarithmicPotential]: + return gp.LMJ09LogarithmicPotential + + @pytest.fixture(scope="class") + def fields_( + self, + field_v_c: u.Quantity, + field_r_s: u.Quantity, + field_q1: u.Quantity, + field_q2: u.Quantity, + field_q3: u.Quantity, + field_phi: u.Quantity, + field_units: AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "v_c": field_v_c, + "r_s": field_r_s, + "q1": field_q1, + "q2": field_q2, + "q3": field_q3, + "phi": field_phi, + "units": field_units, + } + + # ========================================================================== + + def test_potential_energy( + self, pot: LMJ09LogarithmicPotential, x: gt.QVec3 + ) -> None: + expect = Quantity(0.11819267, unit="kpc2 / Myr2") + assert qnp.isclose( + pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: LMJ09LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity([-0.00046885, 0.00181093, 0.00569646], "kpc / Myr2") + assert qnp.allclose( + pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_density(self, pot: LMJ09LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity(48995543.34035844, "solMass / kpc3") + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: LMJ09LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [0.00100608, -0.00070826, 0.00010551], + [-0.00070826, 0.00114681, -0.00040755], + [0.00010551, -0.00040755, 0.00061682], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [8.28469691e-05, -7.08263497e-04, 1.05514716e-04], + [-7.08263497e-04, 2.23569293e-04, -4.07553647e-04], + [1.05514716e-04, -4.07553647e-04, -3.06416262e-04], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # ========================================================================== + # Interoperability + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential_energy", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + # ("density", "density", 1e-8), # TODO: get gala and galax to agree + # ("hessian", "hessian", 1e-8), # TODO: get gala and galax to agree + ], + ) + def test_method_gala( + self, + pot: gp.AbstractPotentialBase, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + """Test the equivalence of methods between gala and galax. + + This test only runs if the potential can be mapped to gala. + """ + super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/builtin/logarithmic/test_logarithmic.py b/tests/unit/potential/builtin/logarithmic/test_logarithmic.py new file mode 100644 index 00000000..20cd608f --- /dev/null +++ b/tests/unit/potential/builtin/logarithmic/test_logarithmic.py @@ -0,0 +1,85 @@ +from typing import Any + +import astropy.units as u +import pytest + +import quaxed.numpy as qnp +from unxt import AbstractUnitSystem, Quantity + +import galax.potential as gp +import galax.typing as gt +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from .test_common import ParameterRSMixin, ParameterVCMixin +from galax.potential import AbstractPotentialBase, LogarithmicPotential + + +class TestLogarithmicPotential( + AbstractPotential_Test, + # Parameters + ParameterVCMixin, + ParameterRSMixin, +): + """Test the `galax.potential.LogarithmicPotential` class.""" + + @pytest.fixture(scope="class") + def pot_cls(self) -> type[gp.LogarithmicPotential]: + return gp.LogarithmicPotential + + @pytest.fixture(scope="class") + def fields_( + self, + field_v_c: u.Quantity, + field_r_s: u.Quantity, + field_units: AbstractUnitSystem, + ) -> dict[str, Any]: + return {"v_c": field_v_c, "r_s": field_r_s, "units": field_units} + + # ========================================================================== + + def test_potential_energy(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity(0.11027593, unit="kpc2 / Myr2") + assert qnp.isclose( + pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity([0.00064902, 0.00129804, 0.00194706], "kpc / Myr2") + assert qnp.allclose( + pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_density(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity(30321621.61178864, "solMass / kpc3") + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [6.32377766e-04, -3.32830403e-05, -4.99245605e-05], + [-3.32830403e-05, 5.82453206e-04, -9.98491210e-05], + [-4.99245605e-05, -9.98491210e-05, 4.99245605e-04], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [6.10189073e-05, -3.32830403e-05, -4.99245605e-05], + [-3.32830403e-05, 1.10943468e-05, -9.98491210e-05], + [-4.99245605e-05, -9.98491210e-05, -7.21132541e-05], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) diff --git a/tests/unit/potential/builtin/misc/__init__.py b/tests/unit/potential/builtin/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/potential/builtin/test_bar.py b/tests/unit/potential/builtin/misc/test_bar.py similarity index 96% rename from tests/unit/potential/builtin/test_bar.py rename to tests/unit/potential/builtin/misc/test_bar.py index bd15e97c..d6516b5a 100644 --- a/tests/unit/potential/builtin/test_bar.py +++ b/tests/unit/potential/builtin/misc/test_bar.py @@ -7,8 +7,8 @@ from unxt import AbstractUnitSystem, Quantity import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ( +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin, diff --git a/tests/unit/potential/builtin/test_hernquist.py b/tests/unit/potential/builtin/misc/test_hernquist.py similarity index 90% rename from tests/unit/potential/builtin/test_hernquist.py rename to tests/unit/potential/builtin/misc/test_hernquist.py index 4fec7dc2..d0412193 100644 --- a/tests/unit/potential/builtin/test_hernquist.py +++ b/tests/unit/potential/builtin/misc/test_hernquist.py @@ -6,10 +6,9 @@ from unxt import Quantity import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeCMixin -from galax.potential import HernquistPotential -from galax.potential._potential.base import AbstractPotentialBase +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeCMixin +from galax.potential import AbstractPotentialBase, HernquistPotential class TestHernquistPotential( diff --git a/tests/unit/potential/builtin/test_isochrone.py b/tests/unit/potential/builtin/misc/test_isochrone.py similarity index 94% rename from tests/unit/potential/builtin/test_isochrone.py rename to tests/unit/potential/builtin/misc/test_isochrone.py index 901151c4..b740ba7f 100644 --- a/tests/unit/potential/builtin/test_isochrone.py +++ b/tests/unit/potential/builtin/misc/test_isochrone.py @@ -7,8 +7,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeBMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeBMixin from galax.potential import AbstractPotentialBase, IsochronePotential diff --git a/tests/unit/potential/builtin/test_jaffe.py b/tests/unit/potential/builtin/misc/test_jaffe.py similarity index 95% rename from tests/unit/potential/builtin/test_jaffe.py rename to tests/unit/potential/builtin/misc/test_jaffe.py index a0f076e4..582fb83d 100644 --- a/tests/unit/potential/builtin/test_jaffe.py +++ b/tests/unit/potential/builtin/misc/test_jaffe.py @@ -8,8 +8,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMMixin, ParameterScaleRadiusMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMMixin, ParameterScaleRadiusMixin from galax.potential import AbstractPotentialBase, JaffePotential diff --git a/tests/unit/potential/builtin/test_kepler.py b/tests/unit/potential/builtin/misc/test_kepler.py similarity index 95% rename from tests/unit/potential/builtin/test_kepler.py rename to tests/unit/potential/builtin/misc/test_kepler.py index c5c5137c..40eaa7c5 100644 --- a/tests/unit/potential/builtin/test_kepler.py +++ b/tests/unit/potential/builtin/misc/test_kepler.py @@ -6,8 +6,8 @@ from unxt import Quantity import galax.potential as gp -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin from galax.potential import AbstractPotentialBase, KeplerPotential from galax.typing import QVec3 diff --git a/tests/unit/potential/builtin/test_kuzmin.py b/tests/unit/potential/builtin/misc/test_kuzmin.py similarity index 96% rename from tests/unit/potential/builtin/test_kuzmin.py rename to tests/unit/potential/builtin/misc/test_kuzmin.py index c1b65815..2c5bc71c 100644 --- a/tests/unit/potential/builtin/test_kuzmin.py +++ b/tests/unit/potential/builtin/misc/test_kuzmin.py @@ -8,8 +8,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeAMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeAMixin from galax.potential import AbstractPotentialBase, KuzminPotential from galax.utils._optional_deps import HAS_GALA diff --git a/tests/unit/potential/builtin/test_leesutotriaxialnfw.py b/tests/unit/potential/builtin/misc/test_leesutotriaxialnfw.py similarity index 96% rename from tests/unit/potential/builtin/test_leesutotriaxialnfw.py rename to tests/unit/potential/builtin/misc/test_leesutotriaxialnfw.py index 31e0c263..425a475f 100644 --- a/tests/unit/potential/builtin/test_leesutotriaxialnfw.py +++ b/tests/unit/potential/builtin/misc/test_leesutotriaxialnfw.py @@ -9,10 +9,9 @@ import galax.potential as gp import galax.typing as gt -from ..param.test_field import ParameterFieldMixin -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMMixin -from .test_nfw import ParameterScaleRadiusMixin +from ...param.test_field import ParameterFieldMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMMixin, ParameterScaleRadiusMixin from galax.utils._optional_deps import HAS_GALA diff --git a/tests/unit/potential/builtin/test_longmuralibar.py b/tests/unit/potential/builtin/misc/test_longmuralibar.py similarity index 96% rename from tests/unit/potential/builtin/test_longmuralibar.py rename to tests/unit/potential/builtin/misc/test_longmuralibar.py index 147b908e..a8c0f646 100644 --- a/tests/unit/potential/builtin/test_longmuralibar.py +++ b/tests/unit/potential/builtin/misc/test_longmuralibar.py @@ -8,9 +8,9 @@ import galax.potential as gp import galax.typing as gt -from ..param.test_field import ParameterFieldMixin -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ( +from ...param.test_field import ParameterFieldMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin, diff --git a/tests/unit/potential/builtin/test_miyamotonagai.py b/tests/unit/potential/builtin/misc/test_miyamotonagai.py similarity index 94% rename from tests/unit/potential/builtin/test_miyamotonagai.py rename to tests/unit/potential/builtin/misc/test_miyamotonagai.py index 983e7d08..188625d7 100644 --- a/tests/unit/potential/builtin/test_miyamotonagai.py +++ b/tests/unit/potential/builtin/misc/test_miyamotonagai.py @@ -7,8 +7,8 @@ from unxt import AbstractUnitSystem, Quantity import galax.potential as gp -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin from galax.potential import AbstractPotentialBase, MiyamotoNagaiPotential from galax.typing import Vec3 diff --git a/tests/unit/potential/builtin/test_null.py b/tests/unit/potential/builtin/misc/test_null.py similarity index 98% rename from tests/unit/potential/builtin/test_null.py rename to tests/unit/potential/builtin/misc/test_null.py index bc0f4ade..f43a9419 100644 --- a/tests/unit/potential/builtin/test_null.py +++ b/tests/unit/potential/builtin/misc/test_null.py @@ -10,7 +10,7 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test +from ...test_core import TestAbstractPotential as AbstractPotential_Test class TestNullPotential(AbstractPotential_Test): diff --git a/tests/unit/potential/builtin/test_plummer.py b/tests/unit/potential/builtin/misc/test_plummer.py similarity index 75% rename from tests/unit/potential/builtin/test_plummer.py rename to tests/unit/potential/builtin/misc/test_plummer.py index 8fcf5046..304121d8 100644 --- a/tests/unit/potential/builtin/test_plummer.py +++ b/tests/unit/potential/builtin/misc/test_plummer.py @@ -8,10 +8,9 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeBMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeBMixin from galax.potential import AbstractPotentialBase, PlummerPotential -from galax.utils._optional_deps import HAS_GALA class TestPlummerPotential( @@ -84,26 +83,3 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: assert qnp.allclose( pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) - - # --------------------------------- - # Interoperability - - @pytest.mark.skipif(not HAS_GALA, reason="requires gala") - @pytest.mark.parametrize( - ("method0", "method1", "atol"), - [ - ("potential_energy", "energy", 1e-8), - ("gradient", "gradient", 1e-8), - ("density", "density", 1e-8), # TODO: why is this different? - ("hessian", "hessian", 1e-8), # TODO: why is gala's 0? - ], - ) - def test_method_gala( - self, - pot: PlummerPotential, - method0: str, - method1: str, - x: gt.QVec3, - atol: float, - ) -> None: - super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/builtin/test_powerlawcutoff.py b/tests/unit/potential/builtin/misc/test_powerlawcutoff.py similarity index 92% rename from tests/unit/potential/builtin/test_powerlawcutoff.py rename to tests/unit/potential/builtin/misc/test_powerlawcutoff.py index a67f9bf7..a4432cfa 100644 --- a/tests/unit/potential/builtin/test_powerlawcutoff.py +++ b/tests/unit/potential/builtin/misc/test_powerlawcutoff.py @@ -8,9 +8,10 @@ import galax.potential as gp import galax.typing as gt -from ..param.test_field import ParameterFieldMixin -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin +from ...io.test_gala import parametrize_test_method_gala +from ...param.test_field import ParameterFieldMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin from galax.potential import AbstractPotentialBase, PowerLawCutoffPotential from galax.utils._optional_deps import GSL_ENABLED, HAS_GALA @@ -149,15 +150,7 @@ def test_galax_to_gala_to_galax_roundtrip( super().test_galax_to_gala_to_galax_roundtrip(pot, x) @pytest.mark.skipif(not HAS_GALA or not GSL_ENABLED, reason="requires gala + GSL") - @pytest.mark.parametrize( - ("method0", "method1", "atol"), - [ - ("potential_energy", "energy", 1e-8), - ("gradient", "gradient", 1e-8), - ("density", "density", 1e-8), - ("hessian", "hessian", 1e-8), - ], - ) + @parametrize_test_method_gala def test_method_gala( self, pot: PowerLawCutoffPotential, diff --git a/tests/unit/potential/builtin/test_satoh.py b/tests/unit/potential/builtin/misc/test_satoh.py similarity index 94% rename from tests/unit/potential/builtin/test_satoh.py rename to tests/unit/potential/builtin/misc/test_satoh.py index ce845474..9ef7d28a 100644 --- a/tests/unit/potential/builtin/test_satoh.py +++ b/tests/unit/potential/builtin/misc/test_satoh.py @@ -8,8 +8,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMTotMixin, ParameterShapeAMixin, ParameterShapeBMixin from galax.potential import AbstractPotentialBase, SatohPotential diff --git a/tests/unit/potential/builtin/test_stone.py b/tests/unit/potential/builtin/misc/test_stone.py similarity index 96% rename from tests/unit/potential/builtin/test_stone.py rename to tests/unit/potential/builtin/misc/test_stone.py index 826c388c..9a35b969 100644 --- a/tests/unit/potential/builtin/test_stone.py +++ b/tests/unit/potential/builtin/misc/test_stone.py @@ -8,8 +8,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterFieldMixin, ParameterMTotMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterFieldMixin, ParameterMTotMixin from galax.potential import AbstractPotentialBase, StoneOstriker15Potential diff --git a/tests/unit/potential/builtin/test_triaxialhernquist.py b/tests/unit/potential/builtin/misc/test_triaxialhernquist.py similarity index 80% rename from tests/unit/potential/builtin/test_triaxialhernquist.py rename to tests/unit/potential/builtin/misc/test_triaxialhernquist.py index 7026d67b..49e757fb 100644 --- a/tests/unit/potential/builtin/test_triaxialhernquist.py +++ b/tests/unit/potential/builtin/misc/test_triaxialhernquist.py @@ -5,16 +5,15 @@ import quaxed.numpy as qnp from unxt import Quantity -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ( +import galax.typing as gt +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( ParameterMTotMixin, ParameterShapeCMixin, ParameterShapeQ1Mixin, ParameterShapeQ2Mixin, ) -from galax.potential import TriaxialHernquistPotential -from galax.potential._potential.base import AbstractPotentialBase -from galax.typing import Vec3 +from galax.potential import AbstractPotentialBase, TriaxialHernquistPotential class TestTriaxialHernquistPotential( @@ -43,13 +42,15 @@ def fields_( # ========================================================================== - def test_potential_energy(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: + def test_potential_energy( + self, pot: TriaxialHernquistPotential, x: gt.QVec3 + ) -> None: expect = Quantity(-0.61215074, pot.units["specific energy"]) assert qnp.isclose( pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) - def test_gradient(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: + def test_gradient(self, pot: TriaxialHernquistPotential, x: gt.QVec3) -> None: expect = Quantity( [0.01312095, 0.02168751, 0.15745134], pot.units["acceleration"] ) @@ -58,10 +59,10 @@ def test_gradient(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: ) @pytest.mark.xfail(reason="WFF?") - def test_density(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: + def test_density(self, pot: TriaxialHernquistPotential, x: gt.QVec3) -> None: assert pot.density(x, t=0).decompose(pot.units).value >= 0 - def test_hessian(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: + def test_hessian(self, pot: TriaxialHernquistPotential, x: gt.QVec3) -> None: expect = Quantity( [ [0.01223294, -0.00146778, -0.0106561], @@ -75,7 +76,7 @@ def test_hessian(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: # --------------------------------- # Convenience methods - def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: """Test the `AbstractPotentialBase.tidal_tensor` method.""" expect = Quantity( [ diff --git a/tests/unit/potential/builtin/nfw/__init__.py b/tests/unit/potential/builtin/nfw/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/potential/builtin/test_nfw.py b/tests/unit/potential/builtin/nfw/test_nfw.py similarity index 95% rename from tests/unit/potential/builtin/test_nfw.py rename to tests/unit/potential/builtin/nfw/test_nfw.py index 033db0e7..3cdb7360 100644 --- a/tests/unit/potential/builtin/test_nfw.py +++ b/tests/unit/potential/builtin/nfw/test_nfw.py @@ -10,8 +10,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ParameterMMixin, ParameterScaleRadiusMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ParameterMMixin, ParameterScaleRadiusMixin ############################################################################### diff --git a/tests/unit/potential/builtin/test_triaxialnfw.py b/tests/unit/potential/builtin/nfw/test_triaxialnfw.py similarity index 96% rename from tests/unit/potential/builtin/test_triaxialnfw.py rename to tests/unit/potential/builtin/nfw/test_triaxialnfw.py index bc5aafea..dea93190 100644 --- a/tests/unit/potential/builtin/test_triaxialnfw.py +++ b/tests/unit/potential/builtin/nfw/test_triaxialnfw.py @@ -8,8 +8,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ( +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( ParameterMMixin, ParameterScaleRadiusMixin, ParameterShapeQ1Mixin, diff --git a/tests/unit/potential/builtin/test_vogelsberger08nfw.py b/tests/unit/potential/builtin/nfw/test_vogelsberger08nfw.py similarity index 96% rename from tests/unit/potential/builtin/test_vogelsberger08nfw.py rename to tests/unit/potential/builtin/nfw/test_vogelsberger08nfw.py index a0dbbf22..1db5f36e 100644 --- a/tests/unit/potential/builtin/test_vogelsberger08nfw.py +++ b/tests/unit/potential/builtin/nfw/test_vogelsberger08nfw.py @@ -8,9 +8,9 @@ import galax.potential as gp import galax.typing as gt -from ..param.test_field import ParameterFieldMixin -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from .test_common import ( +from ...param.test_field import ParameterFieldMixin +from ...test_core import TestAbstractPotential as AbstractPotential_Test +from ..test_common import ( ParameterMMixin, ParameterScaleRadiusMixin, ParameterShapeQ1Mixin, diff --git a/tests/unit/potential/builtin/special/__init__.py b/tests/unit/potential/builtin/special/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/potential/special/test_bovymwpotential2014.py b/tests/unit/potential/builtin/special/test_bovymwpotential2014.py similarity index 92% rename from tests/unit/potential/special/test_bovymwpotential2014.py rename to tests/unit/potential/builtin/special/test_bovymwpotential2014.py index bf27e51a..352f853a 100644 --- a/tests/unit/potential/special/test_bovymwpotential2014.py +++ b/tests/unit/potential/builtin/special/test_bovymwpotential2014.py @@ -11,7 +11,8 @@ import galax.potential as gp import galax.typing as gt -from ..test_composite import AbstractCompositePotential_Test +from ...io.test_gala import parametrize_test_method_gala +from ...test_composite import AbstractCompositePotential_Test from galax.potential import BovyMWPotential2014 from galax.utils._optional_deps import GSL_ENABLED, HAS_GALA @@ -109,15 +110,7 @@ def test_galax_to_gala_to_galax_roundtrip( super().test_galax_to_gala_to_galax_roundtrip(pot, x) @pytest.mark.skipif(not HAS_GALA or not GSL_ENABLED, reason="requires gala + GSL") - @pytest.mark.parametrize( - ("method0", "method1", "atol"), - [ - ("potential_energy", "energy", 1e-8), - ("gradient", "gradient", 1e-8), - ("density", "density", 1e-8), - ("hessian", "hessian", 1e-8), - ], - ) + @parametrize_test_method_gala def test_method_gala( self, pot: BovyMWPotential2014, diff --git a/tests/unit/potential/builtin/special/test_lm10.py b/tests/unit/potential/builtin/special/test_lm10.py new file mode 100644 index 00000000..e91a8f91 --- /dev/null +++ b/tests/unit/potential/builtin/special/test_lm10.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING + +import pytest +from typing_extensions import override + +import quaxed.numpy as qnp +from unxt import Quantity +from unxt.unitsystems import galactic + +import galax.potential as gp +import galax.typing as gt +from ...test_composite import AbstractCompositePotential_Test +from galax.potential import AbstractCompositePotential, LM10Potential +from galax.utils._optional_deps import HAS_GALA + +if TYPE_CHECKING: + from galax.potential import AbstractPotentialBase + + +class TestLM10Potential(AbstractCompositePotential_Test): + """Test the `galax.potential.LM10Potential` class.""" + + @pytest.fixture(scope="class") + def pot_cls(self) -> type[gp.LM10Potential]: + return gp.LM10Potential + + @pytest.fixture(scope="class") + def pot_map(self, pot_cls: type[LM10Potential]) -> dict[str, dict[str, Quantity]]: + """Composite potential.""" + return { + "disk": pot_cls._default_disk, + "bulge": pot_cls._default_bulge, + "halo": pot_cls._default_halo, + } + + # ========================================================================== + + @override + def test_init_units_from_args( + self, + pot_cls: type[AbstractCompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], + ) -> None: + """Test unit system from None.""" + pot = pot_cls(**pot_map, units=None) + assert pot.units == galactic + + # ========================================================================== + + def test_potential_energy(self, pot: LM10Potential, x: gt.QVec3) -> None: + expect = Quantity(-0.00242568, unit="kpc2 / Myr2") + assert qnp.isclose( + pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: LM10Potential, x: gt.QVec3) -> None: + expect = Quantity([0.00278038, 0.00533753, 0.0111171], "kpc / Myr2") + assert qnp.allclose( + pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_density(self, pot: LM10Potential, x: gt.QVec3) -> None: + expect = Quantity(19085831.78310305, "solMass / kpc3") + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: LM10Potential, x: gt.QVec3) -> None: + expect = Quantity( + [ + [0.00234114, -0.00081663, -0.0013405], + [-0.00081663, 0.00100949, -0.00267623], + [-0.0013405, -0.00267623, -0.00227171], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [0.0019815, -0.00081663, -0.0013405], + [-0.00081663, 0.00064985, -0.00267623], + [-0.0013405, -0.00267623, -0.00263135], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # ========================================================================== + # Interoperability + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential_energy", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + # ("density", "density", 1e-8), # TODO: get gala and galax to agree + # ("hessian", "hessian", 1e-8), # TODO: get gala and galax to agree + ], + ) + def test_method_gala( + self, + pot: gp.AbstractPotentialBase, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + """Test the equivalence of methods between gala and galax. + + This test only runs if the potential can be mapped to gala. + """ + super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/special/test_milkywaypotential.py b/tests/unit/potential/builtin/special/test_milkywaypotential.py similarity index 98% rename from tests/unit/potential/special/test_milkywaypotential.py rename to tests/unit/potential/builtin/special/test_milkywaypotential.py index eefaa00f..bc6c2c01 100644 --- a/tests/unit/potential/special/test_milkywaypotential.py +++ b/tests/unit/potential/builtin/special/test_milkywaypotential.py @@ -11,7 +11,7 @@ from unxt.unitsystems import galactic import galax.typing as gt -from ..test_composite import AbstractCompositePotential_Test +from ...test_composite import AbstractCompositePotential_Test from galax.potential import MilkyWayPotential if TYPE_CHECKING: diff --git a/tests/unit/potential/builtin/test_common.py b/tests/unit/potential/builtin/test_common.py index f6300152..e41b688b 100644 --- a/tests/unit/potential/builtin/test_common.py +++ b/tests/unit/potential/builtin/test_common.py @@ -174,6 +174,29 @@ def test_q2_userfunc(self, pot_cls, fields): assert pot.q2(t=0) == Quantity(1.2, "") +class ParameterShapeQ3Mixin(ParameterFieldMixin): + """Test the shape parameter.""" + + @pytest.fixture(scope="class") + def field_q3(self) -> float: + return Quantity(0.5, "") + + # ===================================================== + + def test_q3_constant(self, pot_cls, fields): + """Test the mass parameter.""" + fields["q3"] = Quantity(0.6, "") + pot = pot_cls(**fields) + assert pot.q3(t=0) == Quantity(0.6, "") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_q3_userfunc(self, pot_cls, fields): + """Test the mass parameter.""" + fields["q3"] = lambda t: t * 1.2 + pot = pot_cls(**fields) + assert pot.q3(t=0) == Quantity(1.2, "") + + # ============================================================================= diff --git a/tests/unit/potential/builtin/test_logarithmichalo.py b/tests/unit/potential/builtin/test_logarithmichalo.py deleted file mode 100644 index 804dfca2..00000000 --- a/tests/unit/potential/builtin/test_logarithmichalo.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import Any - -import astropy.units as u -import pytest - -import quaxed.numpy as qnp -from unxt import AbstractUnitSystem, Quantity -from unxt.unitsystems import galactic - -import galax.potential as gp -import galax.typing as gt -from ..param.test_field import ParameterFieldMixin -from ..test_core import TestAbstractPotential as AbstractPotential_Test -from galax.potential import ( - AbstractPotentialBase, - ConstantParameter, - LogarithmicPotential, -) -from galax.utils._optional_deps import HAS_GALA - - -class ParameterVCMixin(ParameterFieldMixin): - """Test the circular velocity parameter.""" - - pot_cls: type[gp.AbstractPotential] - - @pytest.fixture(scope="class") - def field_v_c(self) -> Quantity["speed"]: - return Quantity(220, "km/s") - - # ===================================================== - - def test_v_c_units(self, pot_cls, fields): - """Test the speed parameter.""" - fields["v_c"] = Quantity(1.0, u.Unit(220 * u.km / u.s)) - fields["units"] = galactic - pot = pot_cls(**fields) - assert isinstance(pot.v_c, ConstantParameter) - assert pot.v_c.value == Quantity(220, "km/s") - - def test_v_c_constant(self, pot_cls, fields): - """Test the speed parameter.""" - fields["v_c"] = Quantity(1.0, "km/s") - pot = pot_cls(**fields) - assert pot.v_c(t=0) == Quantity(1.0, "km/s") - - @pytest.mark.xfail(reason="TODO: user function doesn't have units") - def test_v_c_userfunc(self, pot_cls, fields): - """Test the mass parameter.""" - fields["v_c"] = lambda t: t + 2 - pot = pot_cls(**fields) - assert pot.v_c(t=0) == 2 - - -class ParameterRHMixin(ParameterFieldMixin): - """Test the scale radius parameter.""" - - pot_cls: type[gp.AbstractPotential] - - @pytest.fixture(scope="class") - def field_r_h(self) -> Quantity["length"]: - return Quantity(8, "kpc") - - # ===================================================== - - def test_r_h_units(self, pot_cls, fields): - """Test the speed parameter.""" - fields["r_h"] = Quantity(1, u.Unit(10 * u.kpc)) - fields["units"] = galactic - pot = pot_cls(**fields) - assert isinstance(pot.r_h, ConstantParameter) - assert qnp.isclose( - pot.r_h.value, Quantity(10, "kpc"), atol=Quantity(1e-15, "kpc") - ) - - def test_r_h_constant(self, pot_cls, fields): - """Test the speed parameter.""" - fields["r_h"] = Quantity(11.0, "kpc") - pot = pot_cls(**fields) - assert pot.r_h(t=0) == Quantity(11.0, "kpc") - - @pytest.mark.xfail(reason="TODO: user function doesn't have units") - def test_r_h_userfunc(self, pot_cls, fields): - """Test the mass parameter.""" - fields["r_h"] = lambda t: t + 2 - pot = pot_cls(**fields) - assert pot.r_h(t=0) == 2 - - -class TestLogarithmicPotential( - AbstractPotential_Test, - # Parameters - ParameterVCMixin, - ParameterRHMixin, -): - """Test the `galax.potential.LogarithmicPotential` class.""" - - @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.LogarithmicPotential]: - return gp.LogarithmicPotential - - @pytest.fixture(scope="class") - def fields_( - self, - field_v_c: u.Quantity, - field_r_h: u.Quantity, - field_units: AbstractUnitSystem, - ) -> dict[str, Any]: - return {"v_c": field_v_c, "r_h": field_r_h, "units": field_units} - - # ========================================================================== - - def test_potential_energy(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: - expect = Quantity(0.11027593, unit="kpc2 / Myr2") - assert qnp.isclose( - pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) - ) - - def test_gradient(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: - expect = Quantity([0.00064902, 0.00129804, 0.00194706], "kpc / Myr2") - assert qnp.allclose( - pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) - ) - - def test_density(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: - expect = Quantity(30321621.61178864, "solMass / kpc3") - assert qnp.isclose( - pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) - ) - - def test_hessian(self, pot: LogarithmicPotential, x: gt.QVec3) -> None: - expect = Quantity( - [ - [6.32377766e-04, -3.32830403e-05, -4.99245605e-05], - [-3.32830403e-05, 5.82453206e-04, -9.98491210e-05], - [-4.99245605e-05, -9.98491210e-05, 4.99245605e-04], - ], - "1/Myr2", - ) - assert qnp.allclose( - pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) - ) - - # --------------------------------- - # Convenience methods - - def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: - """Test the `AbstractPotentialBase.tidal_tensor` method.""" - expect = Quantity( - [ - [6.10189073e-05, -3.32830403e-05, -4.99245605e-05], - [-3.32830403e-05, 1.10943468e-05, -9.98491210e-05], - [-4.99245605e-05, -9.98491210e-05, -7.21132541e-05], - ], - "1/Myr2", - ) - assert qnp.allclose( - pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) - ) - - # --------------------------------- - # Interoperability - - @pytest.mark.skipif(not HAS_GALA, reason="requires gala") - @pytest.mark.parametrize( - ("method0", "method1", "atol"), - [ - ("potential_energy", "energy", 1e-8), - ("gradient", "gradient", 1e-8), - ("density", "density", 1e-8), # TODO: why is this different? - ("hessian", "hessian", 1e-8), # TODO: why is gala's 0? - ], - ) - def test_method_gala( - self, - pot: LogarithmicPotential, - method0: str, - method1: str, - x: gt.QVec3, - atol: float, - ) -> None: - super().test_method_gala(pot, method0, method1, x, atol) diff --git a/tests/unit/potential/io/gala_helper.py b/tests/unit/potential/io/gala_helper.py index f9bbf6c7..4cefb831 100644 --- a/tests/unit/potential/io/gala_helper.py +++ b/tests/unit/potential/io/gala_helper.py @@ -56,7 +56,7 @@ def galax_to_gala(pot: gpx.AbstractPotentialBase, /) -> gp.PotentialBase: """ msg = ( "`galax_to_gala` does not have a registered function to convert " - f"{pot.__class__.__name__!r} to a `gala.PotentialBase` instance." + f"{pot.__class__.__name__!r} to a galax potential." ) raise NotImplementedError(msg) @@ -80,7 +80,6 @@ def _galax_to_gala_composite(pot: gpx.CompositePotential, /) -> gp.CompositePote @galax_to_gala.register(gpx.IsochronePotential) @galax_to_gala.register(gpx.KeplerPotential) @galax_to_gala.register(gpx.KuzminPotential) -@galax_to_gala.register(gpx.LogarithmicPotential) @galax_to_gala.register(gpx.MiyamotoNagaiPotential) @galax_to_gala.register(gpx.PlummerPotential) @galax_to_gala.register(gpx.PowerLawCutoffPotential) @@ -187,6 +186,46 @@ def _galax_to_gala_stoneostriker15( ) +# ----------------------------------------------------------------------------- +# Logarithmic potentials + + +@galax_to_gala.register +def _galax_to_gala_logarithmic( + pot: gpx.LogarithmicPotential, / +) -> gp.LogarithmicPotential: + """Convert a Galax LogarithmicPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "v_c", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.LogarithmicPotential( + v_c=convert(pot.v_c(0), APYQuantity), + r_h=convert(pot.r_s(0), APYQuantity), + units=galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_logarithmic( + pot: gpx.LMJ09LogarithmicPotential, / +) -> gp.LogarithmicPotential: + """Convert a Galax LogarithmicPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "v_c", "r_s", "q1", "q2", "q3", "phi"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.LogarithmicPotential( + v_c=convert(pot.v_c(0), APYQuantity), + r_h=convert(pot.r_s(0), APYQuantity), + q1=convert(pot.q1(0), APYQuantity), + q2=convert(pot.q2(0), APYQuantity), + q3=convert(pot.q3(0), APYQuantity), + phi=convert(pot.phi(0), APYQuantity), + units=galax_to_gala_units(pot.units), + ) + + # ----------------------------------------------------------------------------- # NFW potentials @@ -251,6 +290,27 @@ def rename(k: str) -> str: ) +@galax_to_gala.register +def _galax_to_gala_lm10(pot: gpx.LM10Potential, /) -> gp.LM10Potential: + """Convert a Galax LM10Potential to a Gala potential.""" + + def rename(k: str) -> str: + match k: + case "m_tot": + return "m" + case "r_s": + return "r_h" + case _: + return k + + return gp.LM10Potential( + **{ + c: {rename(k): getattr(p, k)(0) for k in p.parameters} + for c, p in pot.items() + } + ) + + @galax_to_gala.register def _galax_to_gala_mwpotential(pot: gpx.MilkyWayPotential, /) -> gp.MilkyWayPotential: """Convert a Galax MilkyWayPotential to a Gala potential.""" diff --git a/tests/unit/potential/io/test_gala.py b/tests/unit/potential/io/test_gala.py index b5322c77..9b72d2ea 100644 --- a/tests/unit/potential/io/test_gala.py +++ b/tests/unit/potential/io/test_gala.py @@ -1,8 +1,5 @@ """Testing the gala potential I/O module.""" -from inspect import get_annotations -from typing import ClassVar - import astropy.units as u import pytest from plum import convert @@ -38,15 +35,16 @@ class GalaIOMixin: This is mixed into the ``TestAbstractPotentialBase`` class. """ + # TODO: get this working again # All the Gala-mapped potentials - _GALA_CAN_MAP_TO: ClassVar = set( - [ # get from GALA_TO_GALAX_REGISTRY or the single-dispatch registry - _GALA_TO_GALAX_REGISTRY.get(pot, get_annotations(func)["return"]) - for pot, func in gp.io.gala_to_galax.registry.items() - ] - if HAS_GALA - else [] - ) + # _GALA_CAN_MAP_TO: ClassVar = set( + # [ # get from GALA_TO_GALAX_REGISTRY or the single-dispatch registry + # _GALA_TO_GALAX_REGISTRY.get(pot, get_annotations(func)["return"]) + # for pot, func in gp.io.gala_to_galax.registry.items() + # ] + # if HAS_GALA + # else [] + # ) @pytest.mark.skipif(not HAS_GALA, reason="requires gala") def test_galax_to_gala_to_galax_roundtrip( @@ -56,7 +54,10 @@ def test_galax_to_gala_to_galax_roundtrip( from .gala_helper import galax_to_gala # First we need to check that the potential is gala-compatible - if type(pot) not in self._GALA_CAN_MAP_TO: + # if type(pot) not in self._GALA_CAN_MAP_TO: + try: + galax_to_gala(pot) + except NotImplementedError: pytest.skip(f"potential {pot} cannot be mapped to from gala") rpot = gp.io.gala_to_galax(galax_to_gala(pot)) @@ -87,8 +88,11 @@ def test_method_gala( """ from ..io.gala_helper import galax_to_gala - if type(pot) not in self._GALA_CAN_MAP_TO: - pytest.skip(f"potential {pot} cannot be mapped to gala") + # if type(pot) not in self._GALA_CAN_MAP_TO: + try: + galax_to_gala(pot) + except NotImplementedError: + pytest.skip(f"potential {pot} cannot be mapped to from gala") galax = getattr(pot, method0)(x, t=0) gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr) diff --git a/tests/unit/potential/test_composite.py b/tests/unit/potential/test_composite.py index e5f38189..c6e44130 100644 --- a/tests/unit/potential/test_composite.py +++ b/tests/unit/potential/test_composite.py @@ -14,12 +14,10 @@ from unxt.unitsystems import UnitSystem, dimensionless, galactic, solarsystem import galax.potential as gp -import galax.typing as gt from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test from .test_utils import FieldUnitSystemMixin from galax.typing import Vec3 from galax.utils._misc import first -from galax.utils._optional_deps import HAS_GALA if TYPE_CHECKING: from galax.potential import ( @@ -213,31 +211,6 @@ def test_add_compot(self, pot: CompositePotential) -> None: assert newkey == "kep2" assert newvalue is newpot["kep2"] - # ========================================================================== - - # --------------------------------- - # Interoperability - - @pytest.mark.skipif(not HAS_GALA, reason="requires gala") - @pytest.mark.parametrize( - ("method0", "method1", "atol"), - [ - ("potential_energy", "energy", 1e-8), - ("gradient", "gradient", 1e-8), - ("density", "density", 1e-8), - ("hessian", "hessian", 1e-8), - ], - ) - def test_method_gala( - self, - pot: AbstractCompositePotential, - method0: str, - method1: str, - x: gt.QVec3, - atol: float, - ) -> None: - super().test_method_gala(pot, method0, method1, x, atol) - class TestCompositePotential(AbstractCompositePotential_Test): """Test the `galax.potential.CompositePotential` class."""