diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index 45dc5db7..f31d657c 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -161,17 +161,14 @@ def potential( return potential(self, *args, **kwargs) @partial(jax.jit) - def __call__( - self, q: gt.LengthBatchVec3, /, t: gt.BatchableRealQScalar - ) -> Float[Quantity["specific energy"], "*batch"]: + def __call__(self, *args: Any) -> Float[Quantity["specific energy"], "*batch"]: """Compute the potential energy at the given position(s). Parameters ---------- - q : Quantity[float, (*batch, 3), 'length'] - The position to compute the value of the potential. - t : Array[float | int, *batch] | float | int - The time at which to compute the value of the potential. + *args : Any + Arguments to pass to the potential method. + See :func:`~galax.potential.potential`. Returns ------- @@ -180,9 +177,10 @@ def __call__( See Also -------- + :func:`galax.potential.potential` :meth:`galax.potential.AbstractPotentialBase.potential` """ - return self.potential(q, t) + return self.potential(*args) # --------------------------------------- # Gradient diff --git a/tests/unit/potential/io/test_gala.py b/tests/unit/potential/io/test_gala.py index 153a1a0f..ff4ba32e 100644 --- a/tests/unit/potential/io/test_gala.py +++ b/tests/unit/potential/io/test_gala.py @@ -53,8 +53,8 @@ def test_galax_to_gala_to_galax_roundtrip( rpot = gp.io.gala_to_galax(galax_to_gala(pot)) # quick test that the potential energies are the same - got = rpot(x, t=0) - exp = pot(x, t=0) + got = rpot(x, 0) + exp = pot(x, 0) assert qnp.allclose(got, exp, atol=Quantity(1e-14, exp.unit)) # TODO: add more robust tests diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index c0c69942..b4e69f88 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -134,7 +134,7 @@ def test_potential_batch( def test_call(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: """Test the `AbstractPotentialBase.__call__` method.""" - assert xp.equal(pot(x, t=0), pot.potential(x, t=0)) + assert xp.equal(pot(x, 0), pot.potential(x, 0)) @abstractmethod def test_gradient(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: