From 1a58b68ad93c859c27365297fb7dbba44d25671b Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sat, 16 Mar 2024 20:13:15 -0400 Subject: [PATCH] feat: move units to unxt (#224) * feat: move units to unxt * fix: rebase and smoke * fix: digits Signed-off-by: nstarman --- pyproject.toml | 2 +- src/galax/__init__.py | 3 +- src/galax/__init__.pyi | 2 - src/galax/coordinates/_psp/base.py | 11 +- src/galax/coordinates/_psp/operator_compat.py | 2 +- src/galax/coordinates/_psp/psp.py | 3 +- .../dynamics/_dynamics/integrate/_api.py | 3 +- .../dynamics/_dynamics/integrate/_base.py | 3 +- .../dynamics/_dynamics/integrate/_builtin.py | 3 +- .../dynamics/_dynamics/integrate/_funcs.py | 2 +- .../_dynamics/mockstream/df/fardal.py | 6 +- .../mockstream/mockstream_generator.py | 3 +- src/galax/potential/_potential/base.py | 3 +- src/galax/potential/_potential/builtin.py | 9 +- src/galax/potential/_potential/composite.py | 3 +- src/galax/potential/_potential/core.py | 3 +- src/galax/potential/_potential/frame.py | 5 +- src/galax/potential/_potential/special.py | 4 +- src/galax/potential/_potential/utils.py | 2 +- src/galax/units.py | 306 ------------------ tests/functional/test_mockstreamgenerator.py | 3 +- tests/smoke/test_package.py | 1 - tests/unit/coordinates/psp/test_base.py | 4 +- tests/unit/dynamics/test_orbit.py | 2 +- tests/unit/potential/builtin/test_bar.py | 3 +- tests/unit/potential/builtin/test_common.py | 2 +- .../potential/builtin/test_miyamotonagai.py | 3 +- tests/unit/potential/builtin/test_nfw.py | 2 +- tests/unit/potential/builtin/test_null.py | 13 +- tests/unit/potential/io/gala_helper.py | 3 +- tests/unit/potential/test_base.py | 4 +- tests/unit/potential/test_composite.py | 2 +- tests/unit/potential/test_core.py | 2 +- tests/unit/potential/test_special.py | 2 +- tests/unit/potential/test_utils.py | 4 +- tests/unit/test_units.py | 99 ------ 36 files changed, 54 insertions(+), 473 deletions(-) delete mode 100644 src/galax/units.py delete mode 100644 tests/unit/test_units.py diff --git a/pyproject.toml b/pyproject.toml index 86c49d39..6bc11558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ "Typing :: Typed", ] dependencies = [ - "astropy >= 6", + "astropy >= 6.0", "beartype", "coordinax @ git+https://github.com/GalacticDynamics/coordinax.git", "diffrax >= 0.5", diff --git a/src/galax/__init__.py b/src/galax/__init__.py index 2eec5f9c..e289de6d 100644 --- a/src/galax/__init__.py +++ b/src/galax/__init__.py @@ -6,14 +6,13 @@ "coordinates", "potential", "dynamics", - "units", "utils", "typing", ] from jax import config -from . import coordinates, dynamics, potential, typing, units, utils +from . import coordinates, dynamics, potential, typing, utils from ._version import __version__ config.update("jax_enable_x64", True) # noqa: FBT003 diff --git a/src/galax/__init__.pyi b/src/galax/__init__.pyi index 9ef4d532..20c9af4f 100644 --- a/src/galax/__init__.pyi +++ b/src/galax/__init__.pyi @@ -5,7 +5,6 @@ __all__ = [ "dynamics", "potential", "typing", - "units", "utils", ] @@ -13,7 +12,6 @@ from . import ( dynamics as dynamics, potential as potential, typing as typing, - units as units, utils as utils, ) from ._version import ( # type: ignore[attr-defined] diff --git a/src/galax/coordinates/_psp/base.py b/src/galax/coordinates/_psp/base.py index cc882301..1f260247 100644 --- a/src/galax/coordinates/_psp/base.py +++ b/src/galax/coordinates/_psp/base.py @@ -20,7 +20,7 @@ Cartesian3DVector, represent_as as vector_represent_as, ) -from unxt import Quantity +from unxt import Quantity, unitsystem from .utils import getitem_broadscalartime_index from galax.typing import ( @@ -30,7 +30,6 @@ BatchVec7, BroadBatchFloatQScalar, ) -from galax.units import unitsystem if TYPE_CHECKING: from typing import Self @@ -188,8 +187,8 @@ def w(self, *, units: Any) -> BatchVec6: Parameters ---------- - units : `galax.units.UnitSystem`, optional keyword-only - The unit system. :func:`~galax.units.unitsystem` is used to + units : `unxt.UnitSystem`, optional keyword-only + The unit system. :func:`~unxt.unitsystem` is used to convert the input to a unit system. Returns @@ -228,8 +227,8 @@ def wt(self, *, units: Any) -> BatchVec7: Parameters ---------- - units : `galax.units.UnitSystem`, keyword-only - The unit system. :func:`~galax.units.unitsystem` is used to + units : `unxt.UnitSystem`, keyword-only + The unit system. :func:`~unxt.unitsystem` is used to convert the input to a unit system. Returns diff --git a/src/galax/coordinates/_psp/operator_compat.py b/src/galax/coordinates/_psp/operator_compat.py index 22912827..37f73aa2 100644 --- a/src/galax/coordinates/_psp/operator_compat.py +++ b/src/galax/coordinates/_psp/operator_compat.py @@ -176,7 +176,7 @@ def call( Quantity['length'](Array(2., dtype=float64), unit='kpc') >>> newpsp.t.to("Myr") - Quantity['time'](Array(6.52312755, dtype=float64), unit='Myr') + Quantity['time'](Array(6.52312732, dtype=float64), unit='Myr') This spatial translation is time independent. diff --git a/src/galax/coordinates/_psp/psp.py b/src/galax/coordinates/_psp/psp.py index 23676104..b91da307 100644 --- a/src/galax/coordinates/_psp/psp.py +++ b/src/galax/coordinates/_psp/psp.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from coordinax import Abstract3DVector, Abstract3DVectorDifferential -from unxt import Quantity +from unxt import Quantity, UnitSystem from .base import AbstractPhaseSpacePosition from .utils import _p_converter, _q_converter @@ -20,7 +20,6 @@ QVec1, VecTime, ) -from galax.units import UnitSystem from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape diff --git a/src/galax/dynamics/_dynamics/integrate/_api.py b/src/galax/dynamics/_dynamics/integrate/_api.py index ddb9b4bb..a5a483b4 100644 --- a/src/galax/dynamics/_dynamics/integrate/_api.py +++ b/src/galax/dynamics/_dynamics/integrate/_api.py @@ -2,9 +2,10 @@ from typing import Any, Protocol, runtime_checkable +from unxt import UnitSystem + import galax.typing as gt from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition -from galax.units import UnitSystem from galax.utils.dataclasses import _DataclassInstance diff --git a/src/galax/dynamics/_dynamics/integrate/_base.py b/src/galax/dynamics/_dynamics/integrate/_base.py index 72e3922d..caca8108 100644 --- a/src/galax/dynamics/_dynamics/integrate/_base.py +++ b/src/galax/dynamics/_dynamics/integrate/_base.py @@ -4,10 +4,11 @@ import equinox as eqx +from unxt import UnitSystem + from ._api import FCallable from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition from galax.typing import BatchQVecTime, BatchVec6, BatchVecTime, QVecTime, VecTime -from galax.units import UnitSystem class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc] diff --git a/src/galax/dynamics/_dynamics/integrate/_builtin.py b/src/galax/dynamics/_dynamics/integrate/_builtin.py index eb20c773..36d0cf25 100644 --- a/src/galax/dynamics/_dynamics/integrate/_builtin.py +++ b/src/galax/dynamics/_dynamics/integrate/_builtin.py @@ -10,13 +10,12 @@ import jax import quaxed.array_api as xp -from unxt import Quantity +from unxt import Quantity, UnitSystem import galax.typing as gt from ._api import FCallable from ._base import AbstractIntegrator from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition -from galax.units import UnitSystem from galax.utils import ImmutableDict from galax.utils._jax import vectorize_method diff --git a/src/galax/dynamics/_dynamics/integrate/_funcs.py b/src/galax/dynamics/_dynamics/integrate/_funcs.py index 9d8128cd..22f4b480 100644 --- a/src/galax/dynamics/_dynamics/integrate/_funcs.py +++ b/src/galax/dynamics/_dynamics/integrate/_funcs.py @@ -114,7 +114,7 @@ def evaluate_orbit( >>> import quaxed.array_api as xp # preferred over `jax.numpy` >>> import galax.coordinates as gc >>> import galax.potential as gp - >>> from galax.units import galactic + >>> from unxt.unitsystems import galactic We can then create the point-mass' potential, with galactic units: diff --git a/src/galax/dynamics/_dynamics/mockstream/df/fardal.py b/src/galax/dynamics/_dynamics/mockstream/df/fardal.py index 7c15f690..09d700e6 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/fardal.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/fardal.py @@ -152,8 +152,7 @@ def d2phidr2( -------- >>> from unxt import Quantity >>> from galax.potential import NFWPotential - >>> pot = NFWPotential(m=Quantity(1e12, "Msun"), r_s=Quantity(20.0, "kpc"), - ... units="galactic") + >>> pot = NFWPotential(m=1e12, r_s=20.0, units="galactic") >>> d2phidr2(pot, xp.asarray([8.0, 0.0, 0.0]), t=0) Array(-0.00259193, dtype=float64) """ @@ -250,8 +249,7 @@ def tidal_radius( Examples -------- >>> from galax.potential import NFWPotential - >>> from galax.units import galactic - >>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic) + >>> pot = NFWPotential(m=1e12, r_s=20.0, units="galactic") >>> x=xp.asarray([8.0, 0.0, 0.0]) >>> v=xp.asarray([8.0, 0.0, 0.0]) >>> tidal_radius(pot, x, v, prog_mass=1e4, t=0) diff --git a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py index 63396153..d7c6e252 100644 --- a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py +++ b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py @@ -13,7 +13,7 @@ from jax.lib.xla_bridge import get_backend import quaxed.array_api as xp -from unxt import Quantity +from unxt import Quantity, UnitSystem from .core import MockStream from .df import AbstractStreamDF @@ -24,7 +24,6 @@ from galax.dynamics._dynamics.integrate._funcs import evaluate_orbit from galax.potential._potential.base import AbstractPotentialBase from galax.typing import BatchVec6, FloatScalar, IntScalar, QVecTime, Vec6, VecN -from galax.units import UnitSystem Carry: TypeAlias = tuple[IntScalar, VecN, VecN] diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index 8caae1a3..a8437360 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -20,14 +20,13 @@ import quaxed.numpy as qnp import unxt from coordinax import Abstract3DVector, FourVector -from unxt import Quantity +from unxt import Quantity, UnitSystem import galax.typing as gt from .utils import _convert_from_3dvec, parse_to_quantity from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition from galax.potential._potential.param.attr import ParametersAttribute from galax.potential._potential.param.utils import all_parameters -from galax.units import UnitSystem from galax.utils._collections import ImmutableDict from galax.utils._jax import vectorize_method from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index 45b7100a..d7d5933c 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -20,13 +20,12 @@ from quax import quaxify import quaxed.array_api as xp -from unxt import Quantity +from unxt import Quantity, UnitSystem, unitsystem import galax.typing as gt from galax.potential._potential.base import default_constants from galax.potential._potential.core import AbstractPotential from galax.potential._potential.param import AbstractParameter, ParameterField -from galax.units import UnitSystem, unitsystem from galax.utils import ImmutableDict from galax.utils._jax import vectorize_method @@ -266,10 +265,10 @@ class TriaxialHernquistPotential(AbstractPotential): or constant, like a Quantity. See :class:`~galax.potential.ParameterField` for details. - units : :class:`~galax.units.UnitSystem`, keyword-only + units : :class:`~unxt.UnitSystem`, keyword-only The unit system to use for the potential. This parameter accepts a - :class:`~galax.units.UnitSystem` or anything that can be converted to a - :class:`~galax.units.UnitSystem` using :func:`~galax.units.unitsystem`. + :class:`~unxt.UnitSystem` or anything that can be converted to a + :class:`~unxt.UnitSystem` using :func:`~unxt.unitsystem`. Examples -------- diff --git a/src/galax/potential/_potential/composite.py b/src/galax/potential/_potential/composite.py index 464472b5..58d192e6 100644 --- a/src/galax/potential/_potential/composite.py +++ b/src/galax/potential/_potential/composite.py @@ -10,11 +10,10 @@ import jax import quaxed.array_api as xp -from unxt import Quantity +from unxt import Quantity, UnitSystem, unitsystem from .base import AbstractPotentialBase, default_constants from galax.typing import BatchableRealQScalar, BatchFloatQScalar, BatchQVec3 -from galax.units import UnitSystem, unitsystem from galax.utils import ImmutableDict from galax.utils._misc import first diff --git a/src/galax/potential/_potential/core.py b/src/galax/potential/_potential/core.py index 7d84d57a..0142a7c8 100644 --- a/src/galax/potential/_potential/core.py +++ b/src/galax/potential/_potential/core.py @@ -7,12 +7,11 @@ import equinox as eqx -from unxt import Quantity +from unxt import Quantity, UnitSystem, unitsystem from .base import AbstractPotentialBase, default_constants from .composite import CompositePotential from galax.typing import FloatQScalar, QVec3, RealScalar -from galax.units import UnitSystem, unitsystem from galax.utils import ImmutableDict diff --git a/src/galax/potential/_potential/frame.py b/src/galax/potential/_potential/frame.py index 5c99b702..9b0b93ee 100644 --- a/src/galax/potential/_potential/frame.py +++ b/src/galax/potential/_potential/frame.py @@ -9,11 +9,10 @@ import equinox as eqx from coordinax.operators import OperatorSequence, simplify_op -from unxt import Quantity +from unxt import Quantity, UnitSystem import galax.typing as gt from galax.potential._potential.base import AbstractPotentialBase -from galax.units import UnitSystem from galax.utils import ImmutableDict @@ -90,7 +89,7 @@ class PotentialFrame(AbstractPotentialBase): >>> op2 = cxo.GalileanTranslationOperator(Quantity([1_000, 0, 0, 0], "kpc")) >>> op2.translation.t.to("Myr") - Quantity['time'](Array(3.26156378, dtype=float64), unit='Myr') + Quantity['time'](Array(3.26156366, dtype=float64), unit='Myr') >>> framedpot2 = gp.PotentialFrame(potential=pot, operator=op2) diff --git a/src/galax/potential/_potential/special.py b/src/galax/potential/_potential/special.py index 1a5c10bf..e870f8d0 100644 --- a/src/galax/potential/_potential/special.py +++ b/src/galax/potential/_potential/special.py @@ -11,11 +11,11 @@ import equinox as eqx from unxt import Quantity +from unxt.unitsystems import UnitSystem, dimensionless, galactic, unitsystem from .base import AbstractPotentialBase, default_constants from .builtin import HernquistPotential, MiyamotoNagaiPotential, NFWPotential from .composite import AbstractCompositePotential -from galax.units import UnitSystem, dimensionless, galactic, unitsystem from galax.utils import ImmutableDict T = TypeVar("T", bound=AbstractPotentialBase) @@ -52,7 +52,7 @@ class MilkyWayPotential(AbstractCompositePotential): Parameters ---------- - units : `~galax.units.UnitSystem` (optional) + units : `~unxt.UnitSystem` (optional) Set of non-reducable units that specify (at minimum) the length, mass, time, and angle units. disk : dict (optional) diff --git a/src/galax/potential/_potential/utils.py b/src/galax/potential/_potential/utils.py index 583e2a8c..8dac5b49 100644 --- a/src/galax/potential/_potential/utils.py +++ b/src/galax/potential/_potential/utils.py @@ -16,10 +16,10 @@ import coordinax as cx from unxt import Quantity +from unxt.unitsystems import DimensionlessUnitSystem, UnitSystem, dimensionless from galax.coordinates import AbstractPhaseSpacePosition from galax.typing import Unit -from galax.units import DimensionlessUnitSystem, UnitSystem, dimensionless # -------------------------------------------------------------- diff --git a/src/galax/units.py b/src/galax/units.py deleted file mode 100644 index b7486821..00000000 --- a/src/galax/units.py +++ /dev/null @@ -1,306 +0,0 @@ -"""Tools for representing systems of units using ``astropy.units``.""" - -__all__ = [ - "UnitSystem", - "DimensionlessUnitSystem", - "galactic", - "dimensionless", - "solarsystem", -] - -from collections.abc import Iterator -from typing import ClassVar, Literal, Union, cast, final - -import astropy.units as u -from astropy.units.physical import _physical_unit_mapping -from plum import dispatch - -from unxt import Quantity - -from galax.typing import Unit - - -class UnitSystem: - """Represents a system of units. - - At minimum, this consists of a set of length, time, mass, and angle units, but may - also contain preferred representations for composite units. For example, the base - unit system could be ``{kpc, Myr, Msun, radian}``, but you can also specify a - preferred velocity unit, such as ``km/s``. - - This class behaves like a dictionary with keys set by physical types (i.e. "length", - "velocity", "energy", etc.). If a unit for a particular physical type is not - specified on creation, a composite unit will be created with the base units. See the - examples below for some demonstrations. - - Parameters - ---------- - *units, **units - The units that define the unit system. At minimum, this must contain length, - time, mass, and angle units. If passing in keyword arguments, the keys must be - valid :mod:`astropy.units` physical types. - - Examples - -------- - If only base units are specified, any physical type specified as a key - to this object will be composed out of the base units:: - - >>> usys = UnitSystem(u.m, u.s, u.kg, u.radian) - >>> usys["velocity"] - Unit("m / s") - - However, preferred representations for composite units can also be specified:: - - >>> usys = UnitSystem(u.m, u.s, u.kg, u.radian, u.erg) - >>> usys["energy"] - Unit("m2 kg / s2") - >>> usys.preferred("energy") - Unit("erg") - - This is useful for Galactic dynamics where lengths and times are usually given in - terms of ``kpc`` and ``Myr``, but velocities are often specified in ``km/s``:: - - >>> usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s) - >>> usys["velocity"] - Unit("kpc / Myr") - >>> usys.preferred("velocity") - Unit("km / s") - """ - - _core_units: list[u.UnitBase] - _registry: dict[u.PhysicalType, u.UnitBase] - - _required_dimensions: ClassVar[list[u.PhysicalType]] = [ - u.get_physical_type("length"), - u.get_physical_type("time"), - u.get_physical_type("mass"), - u.get_physical_type("angle"), - ] - - def __init__( - self, - units: Union[u.UnitBase, u.Quantity, Quantity, "UnitSystem"], - *args: u.UnitBase | u.Quantity | Quantity, - ) -> None: - if isinstance(units, UnitSystem): - if len(args) > 0: - msg = "If passing in a UnitSystem, cannot pass in additional units." - raise ValueError(msg) - - self._registry = units._registry.copy() # noqa: SLF001 - self._core_units = units._core_units # noqa: SLF001 - return - - units = (units, *args) - - self._registry = {} - for unit in units: - unit_ = ( # TODO: better detection of allowed unit base classes - unit if isinstance(unit, u.UnitBase) else u.def_unit(f"{unit!s}", unit) - ) - if unit_.physical_type in self._registry: - msg = f"Multiple units passed in with type {unit_.physical_type!r}" - raise ValueError(msg) - self._registry[unit_.physical_type] = unit_ - - self._core_units = [] - for phys_type in self._required_dimensions: - if phys_type not in self._registry: - msg = f"You must specify a unit for the physical type {phys_type!r}" - raise ValueError(msg) - self._core_units.append(self._registry[phys_type]) - - def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase: - key = u.get_physical_type(key) - if key in self._required_dimensions: - return self._registry[key] - - unit = None - for k, v in _physical_unit_mapping.items(): - if v == key: - unit = u.Unit(" ".join([f"{x}**{y}" for x, y in k])) - break - - if unit is None: - msg = f"Physical type '{key}' doesn't exist in unit registry." - raise ValueError(msg) - - unit = unit.decompose(self._core_units) - unit._scale = 1.0 # noqa: SLF001 - return unit - - def __len__(self) -> int: - # Note: This is required for q.decompose(usys) to work, where q is a Quantity - return len(self._core_units) - - def __iter__(self) -> Iterator[u.UnitBase]: - yield from self._core_units - - def __repr__(self) -> str: - return f"UnitSystem({', '.join(str(uu) for uu in self._core_units)})" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, UnitSystem): - return NotImplemented - return bool(self._registry == other._registry) - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - def __hash__(self) -> int: - """Hash the unit system.""" - return hash(tuple(self._core_units) + tuple(self._required_dimensions)) - - def preferred(self, key: str | u.PhysicalType) -> u.UnitBase: - """Return the preferred unit for a given physical type.""" - key = u.get_physical_type(key) - if key in self._registry: - return self._registry[key] - return self[key] - - def as_preferred(self, quantity: Quantity | u.Quantity) -> Quantity: - """Convert a quantity to the preferred unit for this unit system.""" - unit = self.preferred(quantity.unit.physical_type) - return cast(Quantity, Quantity.constructor(quantity.to(unit), unit)) - - -@final -class DimensionlessUnitSystem(UnitSystem): - """A unit system with only dimensionless units.""" - - _required_dimensions: ClassVar[list[u.PhysicalType]] = [] - - def __init__(self) -> None: - super().__init__(u.one) - self._core_units = [u.one] - - def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase: - return u.one - - def __str__(self) -> str: - return "UnitSystem(dimensionless)" - - def __repr__(self) -> str: - return "DimensionlessUnitSystem()" - - -# define galactic unit system -galactic = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km / u.s) - -# solar system units -solarsystem = UnitSystem(u.au, u.M_sun, u.yr, u.radian) - -# dimensionless -dimensionless = DimensionlessUnitSystem() - - -# =========================== -# Unit-system constructor - - -@dispatch -def unitsystem(units: UnitSystem, /) -> UnitSystem: - """Convert a UnitSystem or tuple of arguments to a UnitSystem. - - Examples - -------- - >>> import astropy.units as u - >>> from galax.units import UnitSystem, unitsystem - >>> usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s) - >>> usys - UnitSystem(kpc, Myr, solMass, rad) - - >>> unitsystem(usys) - UnitSystem(kpc, Myr, solMass, rad) - """ - return units - - -@dispatch # type: ignore[no-redef] -def unitsystem( - units: ( - tuple[Unit | u.Quantity | Quantity, ...] | list[Unit | u.Quantity | Quantity] - ), - /, -) -> UnitSystem: - """Convert a UnitSystem or tuple of arguments to a UnitSystem. - - Examples - -------- - >>> import astropy.units as u - >>> from galax.units import UnitSystem, unitsystem - - >>> unitsystem((u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s)) - UnitSystem(kpc, Myr, solMass, rad) - - >>> unitsystem([u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s]) - UnitSystem(kpc, Myr, solMass, rad) - """ - return UnitSystem(*units) if len(units) > 0 else dimensionless - - -@dispatch # type: ignore[no-redef] -def unitsystem(_: None, /) -> UnitSystem: - """Dimensionless unit system from None. - - Examples - -------- - >>> from galax.units import unitsystem - >>> unitsystem(None) - DimensionlessUnitSystem() - """ - return dimensionless - - -@dispatch # type: ignore[no-redef] -def unitsystem(unit0: Unit, /, *units: Unit) -> UnitSystem: - """Convert a set of arguments to a UnitSystem. - - Examples - -------- - >>> import astropy.units as u - >>> from galax.units import UnitSystem, unitsystem - - >>> unitsystem(u.kpc, u.Myr, u.Msun, u.radian) - UnitSystem(kpc, Myr, solMass, rad) - """ - return UnitSystem(unit0, *units) - - -@dispatch # type: ignore[no-redef] -def unitsystem(_: Literal["galactic"], /) -> UnitSystem: - """Galactic unit system by string. - - Examples - -------- - >>> from galax.units import unitsystem - >>> unitsystem("galactic") - UnitSystem(kpc, Myr, solMass, rad) - """ - return galactic - - -@dispatch # type: ignore[no-redef] -def unitsystem(_: Literal["solarsystem"], /) -> UnitSystem: - """Solar system unit system by string. - - Examples - -------- - >>> from galax.units import unitsystem - >>> unitsystem("solarsystem") - UnitSystem(AU, yr, solMass, rad) - """ - return solarsystem - - -@dispatch # type: ignore[no-redef] -def unitsystem(_: Literal["dimensionless"], /) -> UnitSystem: - """Dimensionless unit system by string. - - Examples - -------- - >>> from galax.units import unitsystem - >>> unitsystem("dimensionless") - DimensionlessUnitSystem() - """ - return dimensionless diff --git a/tests/functional/test_mockstreamgenerator.py b/tests/functional/test_mockstreamgenerator.py index 78a9bc0f..bf6b094a 100644 --- a/tests/functional/test_mockstreamgenerator.py +++ b/tests/functional/test_mockstreamgenerator.py @@ -8,12 +8,11 @@ import quax.examples.prng as jr import quaxed.array_api as xp -from unxt import Quantity +from unxt import Quantity, UnitSystem from galax.dynamics import FardalStreamDF, MockStreamGenerator from galax.potential import MilkyWayPotential from galax.typing import FloatQScalar, FloatScalar, QVecTime, Vec6 -from galax.units import UnitSystem usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian) df = FardalStreamDF() diff --git a/tests/smoke/test_package.py b/tests/smoke/test_package.py index b577dc06..e4d2ad39 100644 --- a/tests/smoke/test_package.py +++ b/tests/smoke/test_package.py @@ -18,6 +18,5 @@ def test_all() -> None: "dynamics", "potential", "typing", - "units", "utils", } diff --git a/tests/unit/coordinates/psp/test_base.py b/tests/unit/coordinates/psp/test_base.py index fe950413..3c394ecc 100644 --- a/tests/unit/coordinates/psp/test_base.py +++ b/tests/unit/coordinates/psp/test_base.py @@ -18,13 +18,13 @@ import quaxed.array_api as xp from coordinax import Cartesian3DVector, CartesianDifferential3D from unxt import Quantity +from unxt.unitsystems import galactic from galax.coordinates import AbstractPhaseSpacePosition from galax.coordinates._psp.psp import ComponentShapeTuple from galax.coordinates._psp.utils import _p_converter, _q_converter from galax.potential import AbstractPotentialBase, KeplerPotential from galax.potential._potential.special import MilkyWayPotential -from galax.units import galactic if TYPE_CHECKING: from pytest import FixtureRequest # noqa: PT013 @@ -248,7 +248,7 @@ def wt(self, *, units: UnitSystem | None = None) -> BatchVec7: Parameters ---------- - units : `galax.units.UnitSystem`, optional keyword-only + units : `unxt.UnitSystem`, optional keyword-only The unit system If ``None``, use the current unit system. Returns diff --git a/tests/unit/dynamics/test_orbit.py b/tests/unit/dynamics/test_orbit.py index c58c390d..5e8785bf 100644 --- a/tests/unit/dynamics/test_orbit.py +++ b/tests/unit/dynamics/test_orbit.py @@ -9,6 +9,7 @@ import quaxed.array_api as xp from unxt import Quantity +from unxt.unitsystems import galactic from ..coordinates.psp.test_base import ( AbstractPhaseSpacePosition_Test, @@ -18,7 +19,6 @@ from galax.coordinates import PhaseSpacePosition from galax.dynamics import Orbit from galax.potential import AbstractPotentialBase, MilkyWayPotential -from galax.units import galactic T = TypeVar("T", bound=Orbit) diff --git a/tests/unit/potential/builtin/test_bar.py b/tests/unit/potential/builtin/test_bar.py index 8fa106cf..b85923c6 100644 --- a/tests/unit/potential/builtin/test_bar.py +++ b/tests/unit/potential/builtin/test_bar.py @@ -5,7 +5,7 @@ import pytest import quaxed.numpy as qnp -from unxt import Quantity +from unxt import Quantity, UnitSystem import galax.typing as gt from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -16,7 +16,6 @@ ShapeCParameterMixin, ) from galax.potential import AbstractPotentialBase, BarPotential -from galax.units import UnitSystem class TestBarPotential( diff --git a/tests/unit/potential/builtin/test_common.py b/tests/unit/potential/builtin/test_common.py index 644544cc..086ef411 100644 --- a/tests/unit/potential/builtin/test_common.py +++ b/tests/unit/potential/builtin/test_common.py @@ -2,11 +2,11 @@ import pytest from unxt import Quantity +from unxt.unitsystems import galactic import galax.potential as gp from ..param.test_field import ParameterFieldMixin from galax.potential import ConstantParameter -from galax.units import galactic class MassParameterMixin(ParameterFieldMixin): diff --git a/tests/unit/potential/builtin/test_miyamotonagai.py b/tests/unit/potential/builtin/test_miyamotonagai.py index cec79d0d..117935c2 100644 --- a/tests/unit/potential/builtin/test_miyamotonagai.py +++ b/tests/unit/potential/builtin/test_miyamotonagai.py @@ -4,14 +4,13 @@ import pytest import quaxed.numpy as qnp -from unxt import Quantity +from unxt import Quantity, UnitSystem import galax.potential as gp from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import MassParameterMixin, ShapeAParameterMixin, ShapeBParameterMixin from galax.potential import AbstractPotentialBase, MiyamotoNagaiPotential from galax.typing import Vec3 -from galax.units import UnitSystem class TestMiyamotoNagaiPotential( diff --git a/tests/unit/potential/builtin/test_nfw.py b/tests/unit/potential/builtin/test_nfw.py index 4dafcb1d..1c87a815 100644 --- a/tests/unit/potential/builtin/test_nfw.py +++ b/tests/unit/potential/builtin/test_nfw.py @@ -6,13 +6,13 @@ import quaxed.numpy as qnp from unxt import Quantity +from unxt.unitsystems import UnitSystem, galactic import galax.potential as gp import galax.typing as gt from ..param.test_field import ParameterFieldMixin from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import MassParameterMixin -from galax.units import UnitSystem, galactic from galax.utils._optional_deps import HAS_GALA diff --git a/tests/unit/potential/builtin/test_null.py b/tests/unit/potential/builtin/test_null.py index 3f0bd5d7..1d955999 100644 --- a/tests/unit/potential/builtin/test_null.py +++ b/tests/unit/potential/builtin/test_null.py @@ -7,13 +7,12 @@ from typing_extensions import override import quaxed.numpy as qnp -from unxt import Quantity +import unxt.unitsystems as usx +from unxt import Quantity, UnitSystem import galax.potential as gp import galax.typing as gt -import galax.units as gu from ..test_core import TestAbstractPotential as AbstractPotential_Test -from galax.units import UnitSystem, dimensionless class TestNullPotential(AbstractPotential_Test): @@ -37,7 +36,7 @@ def test_init_units_from_args( # and a numeric value works. fields_unitless.pop("units", None) pot = pot_cls(**fields_unitless, units=None) - assert pot.units == dimensionless + assert pot.units == usx.dimensionless @override def test_init_units_from_name( @@ -47,13 +46,13 @@ def test_init_units_from_name( fields_unitless.pop("units") pot = pot_cls(**fields_unitless, units="dimensionless") - assert pot.units == gu.dimensionless + assert pot.units == usx.dimensionless pot = pot_cls(**fields_unitless, units="solarsystem") - assert pot.units == gu.solarsystem + assert pot.units == usx.solarsystem pot = pot_cls(**fields_unitless, units="galactic") - assert pot.units == gu.galactic + assert pot.units == usx.galactic msg = "`unitsystem('invalid_value')` could not be resolved." with pytest.raises(NotFoundLookupError, match=re.escape(msg)): diff --git a/tests/unit/potential/io/gala_helper.py b/tests/unit/potential/io/gala_helper.py index 18b06351..c09f84c1 100644 --- a/tests/unit/potential/io/gala_helper.py +++ b/tests/unit/potential/io/gala_helper.py @@ -14,9 +14,10 @@ from gala.units import UnitSystem as GalaUnitSystem, dimensionless as gala_dimensionless from plum import convert +from unxt.unitsystems import DimensionlessUnitSystem, UnitSystem + import galax.potential as gp from galax.potential._potential.io.gala import _GALA_TO_GALAX_REGISTRY -from galax.units import DimensionlessUnitSystem, UnitSystem ############################################################################## # UnitSystem diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 2594de16..db0d06af 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -9,14 +9,14 @@ import quaxed.array_api as xp import quaxed.numpy as qnp -from unxt import Quantity +from unxt import Quantity, UnitSystem +from unxt.unitsystems import galactic import galax.dynamics as gd import galax.typing as gt from .io.test_gala import GalaIOMixin from galax.potential import AbstractParameter, AbstractPotentialBase, ParameterField from galax.potential._potential.base import default_constants -from galax.units import UnitSystem, galactic from galax.utils import ImmutableDict diff --git a/tests/unit/potential/test_composite.py b/tests/unit/potential/test_composite.py index 83fa0306..81f16d68 100644 --- a/tests/unit/potential/test_composite.py +++ b/tests/unit/potential/test_composite.py @@ -10,6 +10,7 @@ import quaxed.array_api as xp import quaxed.numpy as qnp from unxt import Quantity +from unxt.unitsystems import UnitSystem, dimensionless, galactic, solarsystem from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test from .test_utils import FieldUnitSystemMixin @@ -21,7 +22,6 @@ NFWPotential, ) from galax.typing import Vec3 -from galax.units import UnitSystem, dimensionless, galactic, solarsystem from galax.utils._misc import first diff --git a/tests/unit/potential/test_core.py b/tests/unit/potential/test_core.py index 6775e1ff..8dc0a2fd 100644 --- a/tests/unit/potential/test_core.py +++ b/tests/unit/potential/test_core.py @@ -7,13 +7,13 @@ import quaxed.array_api as xp from unxt import Quantity +from unxt.unitsystems import UnitSystem, galactic, unitsystem import galax.potential as gp import galax.typing as gt from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test from .test_utils import FieldUnitSystemMixin from galax.potential._potential.base import default_constants -from galax.units import UnitSystem, galactic, unitsystem from galax.utils import ImmutableDict diff --git a/tests/unit/potential/test_special.py b/tests/unit/potential/test_special.py index 0f320a85..c00fa8b3 100644 --- a/tests/unit/potential/test_special.py +++ b/tests/unit/potential/test_special.py @@ -8,6 +8,7 @@ import quaxed.numpy as qnp from unxt import Quantity +from unxt.unitsystems import UnitSystem, galactic, solarsystem import galax.typing as gt from .test_composite import AbstractCompositePotential_Test @@ -17,7 +18,6 @@ KeplerPotential, MilkyWayPotential, ) -from galax.units import UnitSystem, galactic, solarsystem from galax.utils._misc import first ############################################################################## diff --git a/tests/unit/potential/test_utils.py b/tests/unit/potential/test_utils.py index 52ea895d..4836a570 100644 --- a/tests/unit/potential/test_utils.py +++ b/tests/unit/potential/test_utils.py @@ -9,8 +9,10 @@ from jax import Array from plum import NotFoundLookupError +from unxt import UnitSystem, unitsystem +from unxt.unitsystems import dimensionless, galactic, solarsystem + from galax.potential import AbstractPotentialBase -from galax.units import UnitSystem, dimensionless, galactic, solarsystem, unitsystem from galax.utils._optional_deps import HAS_GALA diff --git a/tests/unit/test_units.py b/tests/unit/test_units.py deleted file mode 100644 index 90bb2908..00000000 --- a/tests/unit/test_units.py +++ /dev/null @@ -1,99 +0,0 @@ -import pickle -from pathlib import Path - -import astropy.units as u -import numpy as np -import pytest - -from galax.units import UnitSystem, dimensionless - - -class TestUnitSystem: - """Test :class:`~galax.units.UnitSystem`.""" - - def test_constructor(self) -> None: - """Test the :class:`~galax.units.UnitSystem` constructor.""" - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - - match = "must specify a unit for the physical type .*mass" - with pytest.raises(ValueError, match=match): - UnitSystem(u.kpc, u.Myr, u.radian) # no mass - - match = "must specify a unit for the physical type .*angle" - with pytest.raises(ValueError, match=match): - UnitSystem(u.kpc, u.Myr, u.Msun) - - match = "must specify a unit for the physical type .*time" - with pytest.raises(ValueError, match=match): - UnitSystem(u.kpc, u.radian, u.Msun) - - match = "must specify a unit for the physical type .*length" - with pytest.raises(ValueError, match=match): - UnitSystem(u.Myr, u.radian, u.Msun) - - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - usys = UnitSystem(usys) - - def test_constructor_quantity(self) -> None: - """Test the :class:`~galax.units.UnitSystem` constructor with quantities.""" - usys = UnitSystem(5 * u.kpc, 50 * u.Myr, 1e5 * u.Msun, u.rad) - assert np.isclose((8 * u.Myr).decompose(usys).value, 8 / 50) - - def test_preferred(self) -> None: - """Test the :meth:`~galax.units.UnitSystem.preferred` method.""" - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.km / u.s) - q = 15.0 * u.km / u.s - assert usys.preferred("velocity") == u.km / u.s - assert q.decompose(usys).unit == u.kpc / u.Myr - assert usys.as_preferred(q).unit == u.km / u.s - - # =============================================================== - - def test_compare(self) -> None: - """Test the :meth:`~galax.units.UnitSystem.compare` method.""" - usys1 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) - usys1_clone = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) - - usys2 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.kiloarcsecond / u.yr) - usys3 = UnitSystem(u.kpc, u.Myr, u.radian, u.kg, u.mas / u.yr) - - assert usys1 == usys1_clone - assert usys1_clone == usys1 - - assert usys1 != usys2 - assert usys2 != usys1 - - assert usys1 != usys3 - assert usys3 != usys1 - - def test_pickle(self, tmpdir: Path) -> None: - """Test pickling and unpickling a :class:`~galax.units.UnitSystem`.""" - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - - path = tmpdir / "test.pkl" - with path.open(mode="wb") as f: - pickle.dump(usys, f) - - with path.open(mode="rb") as f: - usys2 = pickle.load(f) - - assert usys == usys2 - - -class TestDimensionlessUnitSystem: - """Test :class:`~galax.units.DimensionlessUnitSystem`.""" - - def test_getitem(self) -> None: - """Test :meth:`~galax.units.DimensionlessUnitSystem.__getitem__`.""" - assert dimensionless["dimensionless"] == u.one - assert dimensionless["length"] == u.one - - def test_decompose(self) -> None: - """Test that dimensionless unitsystem can be decomposed.""" - with pytest.raises(ValueError, match="can not be decomposed into"): - (15 * u.kpc).decompose(dimensionless) - - def test_preferred(self) -> None: - """Test the :meth:`~galax.units.DimensionlessUnitSystem.preferred` method.""" - with pytest.raises(ValueError, match="are not convertible"): - dimensionless.as_preferred(15 * u.kpc)