Skip to content

Commit

Permalink
feat(deps): bump quaxed (GalacticDynamics#201)
Browse files Browse the repository at this point in the history
* feat(deps): use extended quaxed

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 10, 2024
1 parent 9a8e889 commit 425aeed
Show file tree
Hide file tree
Showing 16 changed files with 68 additions and 93 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"Typing :: Typed",
]
dependencies = [
"quaxed >=0.2",
"quaxed >=0.3",
"astropy >= 5.3",
"beartype",
"diffrax",
Expand Down
4 changes: 2 additions & 2 deletions src/galax/coordinates/_psp/operator_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import jax.numpy as jnp
from plum import convert
from quax import quaxify

import quaxed.numpy as qnp
from coordinax import CartesianDifferential3D
from coordinax.operators import (
AbstractCompositeOperator,
Expand Down Expand Up @@ -301,7 +301,7 @@ def call(
return (replace(psp, q=q, p=p), t)


vec_matmul = quaxify(jnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)"))
vec_matmul = qnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)")


@op_call_dispatch
Expand Down
3 changes: 0 additions & 3 deletions tests/unit/dynamics/test_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax.random as jr
import pytest
from plum import convert
from quax import quaxify

import quaxed.array_api as xp
from jax_quantity import Quantity
Expand All @@ -16,8 +15,6 @@
from galax.potential import AbstractPotentialBase, MilkyWayPotential
from galax.units import galactic

array_equal = quaxify(jnp.array_equal)


class TestOrbit(AbstractPhaseSpaceTimePosition_Test[Orbit]):
"""Test :class:`~galax.coordinates.PhaseSpacePosition`."""
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/potential/builtin/test_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import astropy.units as u
import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

from ..test_core import TestAbstractPotential as AbstractPotential_Test
Expand All @@ -19,8 +19,6 @@
from galax.typing import Vec3
from galax.units import UnitSystem

allclose = quaxify(jnp.allclose)


class TestBarPotential(
AbstractPotential_Test,
Expand Down Expand Up @@ -66,7 +64,9 @@ def test_gradient(self, pot: BarPotential, x: Vec3) -> None:
expected = Quantity(
[0.04011905, 0.08383918, 0.16552719], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(
pot.gradient(x, t=0).value, expected.value
) # TODO: not .value

def test_density(self, pot: BarPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0).value, 1.94669274e08)
Expand All @@ -93,4 +93,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.01038389, 0.01590389, -0.04412159],
[-0.02050134, -0.04412159, -0.04753409],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
10 changes: 5 additions & 5 deletions tests/unit/potential/builtin/test_hernquist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

from ..test_core import TestAbstractPotential as AbstractPotential_Test
Expand All @@ -13,8 +13,6 @@
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import Vec3

allclose = quaxify(jnp.allclose)


class TestHernquistPotential(
AbstractPotential_Test,
Expand All @@ -39,7 +37,9 @@ def test_gradient(self, pot: HernquistPotential, x: Vec3) -> None:
expected = Quantity(
[0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(
pot.gradient(x, t=0).value, expected.value
) # TODO: not .value

def test_density(self, pot: HernquistPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0).value, 3.989933e08)
Expand All @@ -66,4 +66,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.01969533, 0.00656511, -0.05908599],
[-0.02954299, -0.05908599, -0.04267321],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
10 changes: 5 additions & 5 deletions tests/unit/potential/builtin/test_isochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

import galax.potential as gp
Expand All @@ -13,8 +13,6 @@
from galax.potential import AbstractPotentialBase, IsochronePotential
from galax.typing import Vec3

allclose = quaxify(jnp.allclose)


class TestIsochronePotential(
AbstractPotential_Test,
Expand All @@ -39,7 +37,9 @@ def test_gradient(self, pot: IsochronePotential, x: Vec3) -> None:
expected = Quantity(
[0.04891392, 0.09782784, 0.14674175], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(
pot.gradient(x, t=0).value, expected.value
) # TODO: not .value

def test_density(self, pot: IsochronePotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0).value, 5.04511665e08)
Expand All @@ -66,4 +66,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.01688883, 0.00562961, -0.05066648],
[-0.02533324, -0.05066648, -0.03659246],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
8 changes: 3 additions & 5 deletions tests/unit/potential/builtin/test_miyamotonagai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import astropy.units as u
import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

import galax.potential as gp
Expand All @@ -15,8 +15,6 @@
from galax.typing import Vec3
from galax.units import UnitSystem

allclose = quaxify(jnp.allclose)


class TestMiyamotoNagaiPotential(
AbstractPotential_Test,
Expand Down Expand Up @@ -50,7 +48,7 @@ def test_gradient(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
expected = Quantity(
[0.04264751, 0.08529503, 0.16840152], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: .value

def test_density(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0).value, 1.9949418e08)
Expand All @@ -77,4 +75,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.01146205, 0.0159643, -0.04525999],
[-0.02262999, -0.04525999, -0.04912166],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
8 changes: 3 additions & 5 deletions tests/unit/potential/builtin/test_nfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import astropy.units as u
import jax.numpy as jnp
import pytest
from quax import quaxify
from typing_extensions import override

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

import galax.potential as gp
Expand All @@ -24,8 +24,6 @@
from galax.units import UnitSystem, galactic
from galax.utils._optional_deps import HAS_GALA

allclose = quaxify(jnp.allclose)


class ScaleRadiusParameterMixin(ParameterFieldMixin):
"""Test the mass parameter."""
Expand Down Expand Up @@ -109,7 +107,7 @@ def test_gradient(self, pot: NFWPotential, x: Vec3) -> None:
expected = Quantity(
[0.0658867, 0.1317734, 0.19766011], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: .value

def test_density(self, pot: NFWPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0).value, 9.46039849e08)
Expand All @@ -136,7 +134,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.02059723, 0.00686574, -0.06179169],
[-0.03089585, -0.06179169, -0.04462733],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))

# ==========================================================================
# I/O
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/potential/builtin/test_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import AbstractPotentialBase, NullPotential
from galax.typing import Vec3
from galax.units import UnitSystem

allclose = quaxify(jnp.allclose)


class TestNullPotential(AbstractPotential_Test):
@pytest.fixture(scope="class")
Expand All @@ -33,7 +31,7 @@ def test_potential_energy(self, pot: NullPotential, x: Vec3) -> None:
def test_gradient(self, pot: NullPotential, x: Vec3) -> None:
"""Test :meth:`NullPotential.gradient`."""
expected = Quantity([0.0, 0.0, 0.0], pot.units["acceleration"])
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: value

def test_density(self, pot: NullPotential, x: Vec3) -> None:
"""Test :meth:`NullPotential.density`."""
Expand All @@ -52,4 +50,4 @@ def test_hessian(self, pot: NullPotential, x: Vec3) -> None:
def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
8 changes: 3 additions & 5 deletions tests/unit/potential/builtin/test_triaxialhernquist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.array_api as xp
import quaxed.numpy as qnp
from jax_quantity import Quantity

from ..test_core import TestAbstractPotential as AbstractPotential_Test
Expand All @@ -18,8 +18,6 @@
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import Vec3

allclose = quaxify(jnp.allclose)


class TestTriaxialHernquistPotential(
AbstractPotential_Test,
Expand Down Expand Up @@ -54,7 +52,7 @@ def test_gradient(self, pot: TriaxialHernquistPotential, x: Vec3) -> None:
expected = Quantity(
[0.01312095, 0.02168751, 0.15745134], pot.units["acceleration"]
)
assert allclose(pot.gradient(x, t=0).value, expected.value) # TODO: not .value
assert qnp.allclose(pot.gradient(x, t=0).value, expected.value) # TODO: value

@pytest.mark.xfail(reason="WFF?")
def test_density(self, pot: TriaxialHernquistPotential, x: Vec3) -> None:
Expand Down Expand Up @@ -82,4 +80,4 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
[-0.00146778, 0.02666394, -0.01761339],
[-0.0106561, -0.01761339, -0.05714314],
]
assert allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
assert qnp.allclose(pot.tidal_tensor(x, t=0), xp.asarray(expect))
8 changes: 3 additions & 5 deletions tests/unit/potential/io/test_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from inspect import get_annotations
from typing import ClassVar

import jax.numpy as jnp
import pytest
from quax import quaxify

import quaxed.numpy as qnp

import galax.potential as gp
from galax.typing import Vec3
Expand All @@ -16,8 +16,6 @@
else:
from galax.potential._potential.io.gala_noop import _GALA_TO_GALAX_REGISTRY

array_equal = quaxify(jnp.array_equal)


class GalaIOMixin:
"""Mixin for testing gala potential I/O.
Expand Down Expand Up @@ -50,4 +48,4 @@ def test_galax_to_gala_to_galax_roundtrip(
rpot = gp.io.gala_to_galax(galax_to_gala(pot))

# quick test that the potential energies are the same
assert array_equal(pot(x, t=0), rpot(x, t=0))
assert qnp.array_equal(pot(x, t=0), rpot(x, t=0))
Loading

0 comments on commit 425aeed

Please sign in to comment.