Skip to content

Commit

Permalink
feat: move units to unxt (GalacticDynamics#224)
Browse files Browse the repository at this point in the history
* feat: move units to unxt
* fix: rebase and smoke
* fix: digits

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 17, 2024
1 parent 4fe1e99 commit 1a58b68
Show file tree
Hide file tree
Showing 36 changed files with 54 additions and 473 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"Typing :: Typed",
]
dependencies = [
"astropy >= 6",
"astropy >= 6.0",
"beartype",
"coordinax @ git+https://github.com/GalacticDynamics/coordinax.git",
"diffrax >= 0.5",
Expand Down
3 changes: 1 addition & 2 deletions src/galax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
"coordinates",
"potential",
"dynamics",
"units",
"utils",
"typing",
]

from jax import config

from . import coordinates, dynamics, potential, typing, units, utils
from . import coordinates, dynamics, potential, typing, utils
from ._version import __version__

config.update("jax_enable_x64", True) # noqa: FBT003
Expand Down
2 changes: 0 additions & 2 deletions src/galax/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ __all__ = [
"dynamics",
"potential",
"typing",
"units",
"utils",
]

from . import (
dynamics as dynamics,
potential as potential,
typing as typing,
units as units,
utils as utils,
)
from ._version import ( # type: ignore[attr-defined]
Expand Down
11 changes: 5 additions & 6 deletions src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Cartesian3DVector,
represent_as as vector_represent_as,
)
from unxt import Quantity
from unxt import Quantity, unitsystem

from .utils import getitem_broadscalartime_index
from galax.typing import (
Expand All @@ -30,7 +30,6 @@
BatchVec7,
BroadBatchFloatQScalar,
)
from galax.units import unitsystem

if TYPE_CHECKING:
from typing import Self
Expand Down Expand Up @@ -188,8 +187,8 @@ def w(self, *, units: Any) -> BatchVec6:
Parameters
----------
units : `galax.units.UnitSystem`, optional keyword-only
The unit system. :func:`~galax.units.unitsystem` is used to
units : `unxt.UnitSystem`, optional keyword-only
The unit system. :func:`~unxt.unitsystem` is used to
convert the input to a unit system.
Returns
Expand Down Expand Up @@ -228,8 +227,8 @@ def wt(self, *, units: Any) -> BatchVec7:
Parameters
----------
units : `galax.units.UnitSystem`, keyword-only
The unit system. :func:`~galax.units.unitsystem` is used to
units : `unxt.UnitSystem`, keyword-only
The unit system. :func:`~unxt.unitsystem` is used to
convert the input to a unit system.
Returns
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_psp/operator_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def call(
Quantity['length'](Array(2., dtype=float64), unit='kpc')
>>> newpsp.t.to("Myr")
Quantity['time'](Array(6.52312755, dtype=float64), unit='Myr')
Quantity['time'](Array(6.52312732, dtype=float64), unit='Myr')
This spatial translation is time independent.
Expand Down
3 changes: 1 addition & 2 deletions src/galax/coordinates/_psp/psp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp

from coordinax import Abstract3DVector, Abstract3DVectorDifferential
from unxt import Quantity
from unxt import Quantity, UnitSystem

from .base import AbstractPhaseSpacePosition
from .utils import _p_converter, _q_converter
Expand All @@ -20,7 +20,6 @@
QVec1,
VecTime,
)
from galax.units import UnitSystem
from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape


Expand Down
3 changes: 2 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import Any, Protocol, runtime_checkable

from unxt import UnitSystem

import galax.typing as gt
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.units import UnitSystem
from galax.utils.dataclasses import _DataclassInstance


Expand Down
3 changes: 2 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import equinox as eqx

from unxt import UnitSystem

from ._api import FCallable
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.typing import BatchQVecTime, BatchVec6, BatchVecTime, QVecTime, VecTime
from galax.units import UnitSystem


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand Down
3 changes: 1 addition & 2 deletions src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem

import galax.typing as gt
from ._api import FCallable
from ._base import AbstractIntegrator
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.units import UnitSystem
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method

Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/integrate/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def evaluate_orbit(
>>> import quaxed.array_api as xp # preferred over `jax.numpy`
>>> import galax.coordinates as gc
>>> import galax.potential as gp
>>> from galax.units import galactic
>>> from unxt.unitsystems import galactic
We can then create the point-mass' potential, with galactic units:
Expand Down
6 changes: 2 additions & 4 deletions src/galax/dynamics/_dynamics/mockstream/df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ def d2phidr2(
--------
>>> from unxt import Quantity
>>> from galax.potential import NFWPotential
>>> pot = NFWPotential(m=Quantity(1e12, "Msun"), r_s=Quantity(20.0, "kpc"),
... units="galactic")
>>> pot = NFWPotential(m=1e12, r_s=20.0, units="galactic")
>>> d2phidr2(pot, xp.asarray([8.0, 0.0, 0.0]), t=0)
Array(-0.00259193, dtype=float64)
"""
Expand Down Expand Up @@ -250,8 +249,7 @@ def tidal_radius(
Examples
--------
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> pot = NFWPotential(m=1e12, r_s=20.0, units="galactic")
>>> x=xp.asarray([8.0, 0.0, 0.0])
>>> v=xp.asarray([8.0, 0.0, 0.0])
>>> tidal_radius(pot, x, v, prog_mass=1e4, t=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jax.lib.xla_bridge import get_backend

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem

from .core import MockStream
from .df import AbstractStreamDF
Expand All @@ -24,7 +24,6 @@
from galax.dynamics._dynamics.integrate._funcs import evaluate_orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec6, FloatScalar, IntScalar, QVecTime, Vec6, VecN
from galax.units import UnitSystem

Carry: TypeAlias = tuple[IntScalar, VecN, VecN]

Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import quaxed.numpy as qnp
import unxt
from coordinax import Abstract3DVector, FourVector
from unxt import Quantity
from unxt import Quantity, UnitSystem

import galax.typing as gt
from .utils import _convert_from_3dvec, parse_to_quantity
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.potential._potential.param.attr import ParametersAttribute
from galax.potential._potential.param.utils import all_parameters
from galax.units import UnitSystem
from galax.utils._collections import ImmutableDict
from galax.utils._jax import vectorize_method
from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims
Expand Down
9 changes: 4 additions & 5 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from quax import quaxify

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

import galax.typing as gt
from galax.potential._potential.base import default_constants
from galax.potential._potential.core import AbstractPotential
from galax.potential._potential.param import AbstractParameter, ParameterField
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method

Expand Down Expand Up @@ -266,10 +265,10 @@ class TriaxialHernquistPotential(AbstractPotential):
or constant, like a Quantity. See
:class:`~galax.potential.ParameterField` for details.
units : :class:`~galax.units.UnitSystem`, keyword-only
units : :class:`~unxt.UnitSystem`, keyword-only
The unit system to use for the potential. This parameter accepts a
:class:`~galax.units.UnitSystem` or anything that can be converted to a
:class:`~galax.units.UnitSystem` using :func:`~galax.units.unitsystem`.
:class:`~unxt.UnitSystem` or anything that can be converted to a
:class:`~unxt.UnitSystem` using :func:`~unxt.unitsystem`.
Examples
--------
Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

from .base import AbstractPotentialBase, default_constants
from galax.typing import BatchableRealQScalar, BatchFloatQScalar, BatchQVec3
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict
from galax.utils._misc import first

Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

import equinox as eqx

from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

from .base import AbstractPotentialBase, default_constants
from .composite import CompositePotential
from galax.typing import FloatQScalar, QVec3, RealScalar
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict


Expand Down
5 changes: 2 additions & 3 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import equinox as eqx

from coordinax.operators import OperatorSequence, simplify_op
from unxt import Quantity
from unxt import Quantity, UnitSystem

import galax.typing as gt
from galax.potential._potential.base import AbstractPotentialBase
from galax.units import UnitSystem
from galax.utils import ImmutableDict


Expand Down Expand Up @@ -90,7 +89,7 @@ class PotentialFrame(AbstractPotentialBase):
>>> op2 = cxo.GalileanTranslationOperator(Quantity([1_000, 0, 0, 0], "kpc"))
>>> op2.translation.t.to("Myr")
Quantity['time'](Array(3.26156378, dtype=float64), unit='Myr')
Quantity['time'](Array(3.26156366, dtype=float64), unit='Myr')
>>> framedpot2 = gp.PotentialFrame(potential=pot, operator=op2)
Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import equinox as eqx

from unxt import Quantity
from unxt.unitsystems import UnitSystem, dimensionless, galactic, unitsystem

from .base import AbstractPotentialBase, default_constants
from .builtin import HernquistPotential, MiyamotoNagaiPotential, NFWPotential
from .composite import AbstractCompositePotential
from galax.units import UnitSystem, dimensionless, galactic, unitsystem
from galax.utils import ImmutableDict

T = TypeVar("T", bound=AbstractPotentialBase)
Expand Down Expand Up @@ -52,7 +52,7 @@ class MilkyWayPotential(AbstractCompositePotential):
Parameters
----------
units : `~galax.units.UnitSystem` (optional)
units : `~unxt.UnitSystem` (optional)
Set of non-reducable units that specify (at minimum) the
length, mass, time, and angle units.
disk : dict (optional)
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

import coordinax as cx
from unxt import Quantity
from unxt.unitsystems import DimensionlessUnitSystem, UnitSystem, dimensionless

from galax.coordinates import AbstractPhaseSpacePosition
from galax.typing import Unit
from galax.units import DimensionlessUnitSystem, UnitSystem, dimensionless

# --------------------------------------------------------------

Expand Down
Loading

0 comments on commit 1a58b68

Please sign in to comment.