diff --git a/pyproject.toml b/pyproject.toml index 37876fac..8fd9fdac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ "Typing :: Typed", ] dependencies = [ - "quaxed >=0.2", + "quaxed >=0.3", "astropy >= 5.3", "beartype", "diffrax", diff --git a/src/galax/coordinates/_psp/operator_compat.py b/src/galax/coordinates/_psp/operator_compat.py index 934fbe08..26759ff3 100644 --- a/src/galax/coordinates/_psp/operator_compat.py +++ b/src/galax/coordinates/_psp/operator_compat.py @@ -6,8 +6,8 @@ import jax.numpy as jnp from plum import convert -from quax import quaxify +import quaxed.numpy as qnp from coordinax import CartesianDifferential3D from coordinax.operators import ( AbstractCompositeOperator, @@ -301,7 +301,7 @@ def call( return (replace(psp, q=q, p=p), t) -vec_matmul = quaxify(jnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)")) +vec_matmul = qnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)") @op_call_dispatch diff --git a/tests/unit/dynamics/test_orbit.py b/tests/unit/dynamics/test_orbit.py index 45eefee3..3e594efa 100644 --- a/tests/unit/dynamics/test_orbit.py +++ b/tests/unit/dynamics/test_orbit.py @@ -4,7 +4,6 @@ import jax.random as jr import pytest from plum import convert -from quax import quaxify import quaxed.array_api as xp from jax_quantity import Quantity @@ -16,8 +15,6 @@ from galax.potential import AbstractPotentialBase, MilkyWayPotential from galax.units import galactic -array_equal = quaxify(jnp.array_equal) - class TestOrbit(AbstractPhaseSpaceTimePosition_Test[Orbit]): """Test :class:`~galax.coordinates.PhaseSpacePosition`.""" diff --git a/tests/unit/potential/builtin/test_bar.py b/tests/unit/potential/builtin/test_bar.py index 44ca866e..bfb72158 100644 --- a/tests/unit/potential/builtin/test_bar.py +++ b/tests/unit/potential/builtin/test_bar.py @@ -3,9 +3,9 @@ import astropy.units as u import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -19,8 +19,6 @@ from galax.typing import Vec3 from galax.units import UnitSystem -allclose = quaxify(jnp.allclose) - class TestBarPotential( AbstractPotential_Test, @@ -66,7 +64,9 @@ def test_gradient(self, pot: BarPotential, x: Vec3) -> None: expected = Quantity( [0.04011905, 0.08383918, 0.16552719], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose( + pot.gradient(x, t=0).value, expected.value + ) # TODO: not .value def test_density(self, pot: BarPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 1.94669274e08) @@ -93,4 +93,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.01038389, 0.01590389, -0.04412159], [-0.02050134, -0.04412159, -0.04753409], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/builtin/test_hernquist.py b/tests/unit/potential/builtin/test_hernquist.py index 086a32eb..b1cf47bf 100644 --- a/tests/unit/potential/builtin/test_hernquist.py +++ b/tests/unit/potential/builtin/test_hernquist.py @@ -2,9 +2,9 @@ import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -13,8 +13,6 @@ from galax.potential._potential.base import AbstractPotentialBase from galax.typing import Vec3 -allclose = quaxify(jnp.allclose) - class TestHernquistPotential( AbstractPotential_Test, @@ -39,7 +37,9 @@ def test_gradient(self, pot: HernquistPotential, x: Vec3) -> None: expected = Quantity( [0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose( + pot.gradient(x, t=0).value, expected.value + ) # TODO: not .value def test_density(self, pot: HernquistPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 3.989933e08) @@ -66,4 +66,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.01969533, 0.00656511, -0.05908599], [-0.02954299, -0.05908599, -0.04267321], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/builtin/test_isochrone.py b/tests/unit/potential/builtin/test_isochrone.py index 3b53a38c..3c3d44dc 100644 --- a/tests/unit/potential/builtin/test_isochrone.py +++ b/tests/unit/potential/builtin/test_isochrone.py @@ -2,9 +2,9 @@ import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity import galax.potential as gp @@ -13,8 +13,6 @@ from galax.potential import AbstractPotentialBase, IsochronePotential from galax.typing import Vec3 -allclose = quaxify(jnp.allclose) - class TestIsochronePotential( AbstractPotential_Test, @@ -39,7 +37,9 @@ def test_gradient(self, pot: IsochronePotential, x: Vec3) -> None: expected = Quantity( [0.04891392, 0.09782784, 0.14674175], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose( + pot.gradient(x, t=0).value, expected.value + ) # TODO: not .value def test_density(self, pot: IsochronePotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 5.04511665e08) @@ -66,4 +66,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.01688883, 0.00562961, -0.05066648], [-0.02533324, -0.05066648, -0.03659246], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/builtin/test_miyamotonagai.py b/tests/unit/potential/builtin/test_miyamotonagai.py index 50037835..baa1b03d 100644 --- a/tests/unit/potential/builtin/test_miyamotonagai.py +++ b/tests/unit/potential/builtin/test_miyamotonagai.py @@ -3,9 +3,9 @@ import astropy.units as u import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity import galax.potential as gp @@ -15,8 +15,6 @@ from galax.typing import Vec3 from galax.units import UnitSystem -allclose = quaxify(jnp.allclose) - class TestMiyamotoNagaiPotential( AbstractPotential_Test, @@ -50,7 +48,7 @@ def test_gradient(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: expected = Quantity( [0.04264751, 0.08529503, 0.16840152], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: .value def test_density(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 1.9949418e08) @@ -77,4 +75,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.01146205, 0.0159643, -0.04525999], [-0.02262999, -0.04525999, -0.04912166], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/builtin/test_nfw.py b/tests/unit/potential/builtin/test_nfw.py index 1dfdf8fb..51955100 100644 --- a/tests/unit/potential/builtin/test_nfw.py +++ b/tests/unit/potential/builtin/test_nfw.py @@ -4,10 +4,10 @@ import astropy.units as u import jax.numpy as jnp import pytest -from quax import quaxify from typing_extensions import override import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity import galax.potential as gp @@ -24,8 +24,6 @@ from galax.units import UnitSystem, galactic from galax.utils._optional_deps import HAS_GALA -allclose = quaxify(jnp.allclose) - class ScaleRadiusParameterMixin(ParameterFieldMixin): """Test the mass parameter.""" @@ -109,7 +107,7 @@ def test_gradient(self, pot: NFWPotential, x: Vec3) -> None: expected = Quantity( [0.0658867, 0.1317734, 0.19766011], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: .value def test_density(self, pot: NFWPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 9.46039849e08) @@ -136,7 +134,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.02059723, 0.00686574, -0.06179169], [-0.03089585, -0.06179169, -0.04462733], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) # ========================================================================== # I/O diff --git a/tests/unit/potential/builtin/test_null.py b/tests/unit/potential/builtin/test_null.py index 7cd9d778..f48ec119 100644 --- a/tests/unit/potential/builtin/test_null.py +++ b/tests/unit/potential/builtin/test_null.py @@ -2,9 +2,9 @@ import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -12,8 +12,6 @@ from galax.typing import Vec3 from galax.units import UnitSystem -allclose = quaxify(jnp.allclose) - class TestNullPotential(AbstractPotential_Test): @pytest.fixture(scope="class") @@ -33,7 +31,7 @@ def test_potential_energy(self, pot: NullPotential, x: Vec3) -> None: def test_gradient(self, pot: NullPotential, x: Vec3) -> None: """Test :meth:`NullPotential.gradient`.""" expected = Quantity([0.0, 0.0, 0.0], pot.units["acceleration"]) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: value def test_density(self, pot: NullPotential, x: Vec3) -> None: """Test :meth:`NullPotential.density`.""" @@ -52,4 +50,4 @@ def test_hessian(self, pot: NullPotential, x: Vec3) -> None: def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.tidal_tensor` method.""" expect = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/builtin/test_triaxialhernquist.py b/tests/unit/potential/builtin/test_triaxialhernquist.py index 839b3897..a0cbf851 100644 --- a/tests/unit/potential/builtin/test_triaxialhernquist.py +++ b/tests/unit/potential/builtin/test_triaxialhernquist.py @@ -2,9 +2,9 @@ import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -18,8 +18,6 @@ from galax.potential._potential.base import AbstractPotentialBase from galax.typing import Vec3 -allclose = quaxify(jnp.allclose) - class TestTriaxialHernquistPotential( AbstractPotential_Test, @@ -54,7 +52,7 @@ def test_gradient(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: expected = Quantity( [0.01312095, 0.02168751, 0.15745134], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: value @pytest.mark.xfail(reason="WFF?") def test_density(self, pot: TriaxialHernquistPotential, x: Vec3) -> None: @@ -82,4 +80,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.00146778, 0.02666394, -0.01761339], [-0.0106561, -0.01761339, -0.05714314], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/io/test_gala.py b/tests/unit/potential/io/test_gala.py index bbb4d17e..b16c7fc5 100644 --- a/tests/unit/potential/io/test_gala.py +++ b/tests/unit/potential/io/test_gala.py @@ -3,9 +3,9 @@ from inspect import get_annotations from typing import ClassVar -import jax.numpy as jnp import pytest -from quax import quaxify + +import quaxed.numpy as qnp import galax.potential as gp from galax.typing import Vec3 @@ -16,8 +16,6 @@ else: from galax.potential._potential.io.gala_noop import _GALA_TO_GALAX_REGISTRY -array_equal = quaxify(jnp.array_equal) - class GalaIOMixin: """Mixin for testing gala potential I/O. @@ -50,4 +48,4 @@ def test_galax_to_gala_to_galax_roundtrip( rpot = gp.io.gala_to_galax(galax_to_gala(pot)) # quick test that the potential energies are the same - assert array_equal(pot(x, t=0), rpot(x, t=0)) + assert qnp.array_equal(pot(x, t=0), rpot(x, t=0)) diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 6e32ed46..c776e2e0 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -5,11 +5,10 @@ import astropy.units as u import equinox as eqx import jax -import jax.numpy as jnp import pytest -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity import galax.dynamics as gd @@ -27,8 +26,6 @@ from galax.units import UnitSystem, galactic from galax.utils._jax import vectorize_method -array_equal = quaxify(jnp.array_equal) - class TestAbstractPotentialBase(GalaIOMixin): """Test the `galax.potential.AbstractPotentialBase` class.""" @@ -141,7 +138,7 @@ def test_potential_energy_batch( # Test that the method works on batches. assert pot.potential_energy(batchx, t=0).shape == batchx.shape[:-1] # Test that the batched method is equivalent to the scalar method - assert array_equal( + assert qnp.array_equal( pot.potential_energy(batchx, t=0)[0], pot.potential_energy(batchx[0], t=0) ) @@ -154,7 +151,7 @@ def test_call(self, pot: AbstractPotentialBase, x: Vec3) -> None: def test_gradient(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.gradient` method.""" expected = Quantity(xp.ones_like(x), pot.units["acceleration"]) - assert array_equal(pot.gradient(x, t=0), expected) + assert qnp.array_equal(pot.gradient(x, t=0), expected) def test_density(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.density` method.""" @@ -162,14 +159,14 @@ def test_density(self, pot: AbstractPotentialBase, x: Vec3) -> None: def test_hessian(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.hessian` method.""" - assert array_equal( + assert qnp.array_equal( pot.hessian(x, t=0), xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), ) def test_acceleration(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.acceleration` method.""" - assert array_equal(pot.acceleration(x, t=0), -pot.gradient(x, t=0)) + assert qnp.array_equal(pot.acceleration(x, t=0), -pot.gradient(x, t=0)) # --------------------------------- # Convenience methods @@ -177,7 +174,7 @@ def test_acceleration(self, pot: AbstractPotentialBase, x: Vec3) -> None: def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: """Test the `AbstractPotentialBase.tidal_tensor` method.""" expect = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] - assert array_equal(pot.tidal_tensor(x, t=0), expect) + assert qnp.array_equal(pot.tidal_tensor(x, t=0), expect) # ========================================================================= @@ -188,7 +185,7 @@ def test_integrate_orbit(self, pot: AbstractPotentialBase, xv: Vec6) -> None: orbit = pot.integrate_orbit(xv, ts) assert isinstance(orbit, gd.Orbit) assert orbit.shape == (len(ts.value),) # TODO: don't use .value - assert array_equal(orbit.t, ts) + assert qnp.array_equal(orbit.t, ts) def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> None: """Test the `AbstractPotentialBase.integrate_orbit` method.""" @@ -198,11 +195,11 @@ def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> No orbits = pot.integrate_orbit(xv[None, :], ts) assert isinstance(orbits, gd.Orbit) assert orbits.shape == (1, len(ts)) - assert array_equal(orbits.t, ts) + assert qnp.array_equal(orbits.t, ts) # More complicated batch xv2 = xp.stack([xv, xv], axis=0) orbits = pot.integrate_orbit(xv2, ts) assert isinstance(orbits, gd.Orbit) assert orbits.shape == (2, len(ts)) - assert array_equal(orbits.t, ts) + assert qnp.array_equal(orbits.t, ts) diff --git a/tests/unit/potential/test_composite.py b/tests/unit/potential/test_composite.py index e6343e77..7ef2e36b 100644 --- a/tests/unit/potential/test_composite.py +++ b/tests/unit/potential/test_composite.py @@ -6,10 +6,10 @@ import jax.numpy as jnp import pytest from plum import NotFoundLookupError -from quax import quaxify from typing_extensions import override import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test @@ -25,9 +25,6 @@ from galax.units import UnitSystem, dimensionless, galactic, solarsystem from galax.utils._misc import first -array_equal = quaxify(jnp.array_equal) -allclose = quaxify(jnp.allclose) - # TODO: write the base-class test class AbstractCompositePotential_Test(AbstractPotentialBase_Test, FieldUnitSystemMixin): @@ -269,7 +266,9 @@ def test_gradient(self, pot: CompositePotential, x: Vec3) -> None: expected = Quantity( [0.01124388, 0.02248775, 0.03382281], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose( + pot.gradient(x, t=0).value, expected.value + ) # TODO: not .value def test_density(self, pot: CompositePotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0).value, 2.7958598e08) @@ -296,4 +295,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.0025614, 0.00085275, -0.00768793], [-0.00384397, -0.00768793, -0.00554761], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/potential/test_frame.py b/tests/unit/potential/test_frame.py index 24e1b014..fe1e6f76 100644 --- a/tests/unit/potential/test_frame.py +++ b/tests/unit/potential/test_frame.py @@ -2,17 +2,13 @@ from dataclasses import replace -import jax.numpy as jnp -from quax import quaxify - import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity import galax.coordinates.operators as gco import galax.potential as gp -array_equal = quaxify(jnp.array_equal) - def test_bar_means_of_rotation() -> None: """Test the equivalence of hard-coded vs operator means of rotation.""" @@ -39,19 +35,19 @@ def test_bar_means_of_rotation() -> None: q = Quantity([5, 0, 0], "kpc") t = Quantity(0, "Myr") assert framedpot.potential_energy(q, t) == hardpot.potential_energy(q, t) - assert array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) + assert qnp.array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) # They should be equivalent at t=110 Myr (1/2 period) t = Quantity(110, "Myr") assert framedpot.potential_energy(q, t) == hardpot.potential_energy(q, t) - assert array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) + assert qnp.array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) # They should be equivalent at t=220 Myr (1 period) t = Quantity(220, "Myr") assert framedpot.potential_energy(q, t) == hardpot.potential_energy(q, t) - assert array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) + assert qnp.array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) # They should be equivalent at t=55 Myr (1/4 period) t = Quantity(55, "Myr") assert framedpot.potential_energy(q, t) == hardpot.potential_energy(q, t) - assert array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) + assert qnp.array_equal(framedpot.acceleration(q, t), hardpot.acceleration(q, t)) diff --git a/tests/unit/potential/test_special.py b/tests/unit/potential/test_special.py index 32660d25..e2e5ed34 100644 --- a/tests/unit/potential/test_special.py +++ b/tests/unit/potential/test_special.py @@ -5,10 +5,10 @@ import jax.numpy as jnp import pytest from plum import NotFoundLookupError -from quax import quaxify from typing_extensions import override import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from .test_composite import AbstractCompositePotential_Test @@ -22,9 +22,6 @@ from galax.units import UnitSystem, dimensionless, galactic, solarsystem from galax.utils._misc import first -allclose = quaxify(jnp.allclose) - - ############################################################################## @@ -256,7 +253,9 @@ def test_gradient(self, pot: MilkyWayPotential, x: Vec3) -> None: expected = Quantity( [0.00256403, 0.00512806, 0.01115272], pot.units["acceleration"] ) - assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value + assert qnp.allclose( + pot.gradient(x, t=0).value, expected.value + ) # TODO: not .value def test_density(self, pot: MilkyWayPotential, x: Vec3) -> None: """Test the :meth:`MilkyWayPotential.density` method.""" @@ -264,7 +263,7 @@ def test_density(self, pot: MilkyWayPotential, x: Vec3) -> None: def test_hessian(self, pot: MilkyWayPotential, x: Vec3) -> None: """Test the :meth:`MilkyWayPotential.hessian` method.""" - assert allclose( + assert qnp.allclose( pot.hessian(x, t=0), xp.asarray( [ @@ -285,4 +284,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: [-0.00050698, 0.00092134, -0.00202546], [-0.00101273, -0.00202546, -0.00260316], ] - assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) + assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect)) diff --git a/tests/unit/utils/test_shape.py b/tests/unit/utils/test_shape.py index 3a8dd479..f75091c7 100644 --- a/tests/unit/utils/test_shape.py +++ b/tests/unit/utils/test_shape.py @@ -6,10 +6,9 @@ import jax import jax.numpy as jnp import pytest -from jax.numpy import array_equal -from quax import quaxify import quaxed.array_api as xp +import quaxed.numpy as qnp from jax_quantity import Quantity from galax.utils._shape import ( @@ -19,8 +18,6 @@ expand_batch_dims, ) -array_equal = quaxify(array_equal) - class TestAtleastBatched: """Test the `atleast_batched` function.""" @@ -37,8 +34,8 @@ def test_atleast_batched_example(self) -> None: """Test the `atleast_batched` function with an example.""" x = xp.asarray([1, 2, 3]) # `atleast_batched` versus `atleast_2d` - assert array_equal(atleast_batched(x), x[:, None]) - assert array_equal(jnp.atleast_2d(x), x[None, :]) + assert qnp.array_equal(atleast_batched(x), x[:, None]) + assert qnp.array_equal(jnp.atleast_2d(x), x[None, :]) @pytest.mark.parametrize( ("x", "expect"), @@ -60,7 +57,7 @@ def test_atleast_batched_example(self) -> None: def test_atleast_batched_one_arg(self, x: Any, expect: Any) -> None: """Test the `atleast_batched` function with one argument.""" got = atleast_batched(xp.asarray(x)) - assert array_equal(got, xp.asarray(expect)) + assert qnp.array_equal(got, xp.asarray(expect)) assert got.ndim >= 2 def test_atleast_batched_multiple_args(self) -> None: @@ -71,8 +68,8 @@ def test_atleast_batched_multiple_args(self) -> None: result = atleast_batched(x, y) assert isinstance(result, tuple) assert len(result) == 2 - assert array_equal(result[0], x[:, None]) - assert array_equal(result[1], y[:, None]) + assert qnp.array_equal(result[0], x[:, None]) + assert qnp.array_equal(result[1], y[:, None]) # Quantity x = Quantity(x, "m") @@ -82,8 +79,8 @@ def test_atleast_batched_multiple_args(self) -> None: assert len(result) == 2 assert isinstance(result[0], Quantity) assert isinstance(result[1], Quantity) - assert array_equal(result[0], Quantity(x.value[:, None], "m")) - assert array_equal(result[1], Quantity(y.value[:, None], "m")) + assert qnp.array_equal(result[0], Quantity(x.value[:, None], "m")) + assert qnp.array_equal(result[1], Quantity(y.value[:, None], "m")) class TestBatchedShape: @@ -147,7 +144,7 @@ def test_expand_batch_dims( ) -> None: """Test :func:`galax.utils._shape.expand_batch_dims`.""" got = expand_batch_dims(arr, ndim=ndim) - assert array_equal(got, expect) + assert qnp.array_equal(got, expect) assert got.shape == expect.shape @@ -180,5 +177,5 @@ def test_expand_arr_dims( ) -> None: """Test :func:`galax.utils._shape.expand_arr_dims`.""" got = expand_arr_dims(arr, ndim=ndim) - assert array_equal(got, expect) + assert qnp.array_equal(got, expect) assert got.shape == expect.shape