diff --git a/src/galax/utils/dataclasses.py b/src/galax/utils/dataclasses.py index 18f11fab..95c04800 100644 --- a/src/galax/utils/dataclasses.py +++ b/src/galax/utils/dataclasses.py @@ -21,15 +21,9 @@ ) import astropy.units as u -import jax.numpy as jnp from equinox._module import _has_dataclass_init, _ModuleMeta -from jax.dtypes import canonicalize_dtype -from jaxtyping import Array, Float from typing_extensions import ParamSpec, Unpack -import quaxed.array_api as xp -from unxt import Quantity - import galax.typing as gt if TYPE_CHECKING: @@ -264,24 +258,6 @@ def __new__( # noqa: D102 # pylint: disable=signature-differs return cls -############################################################################## -# Converters - - -@ft.singledispatch -def converter_float_array(x: Any, /) -> Float[Array, "*shape"]: - """Convert to a batched vector.""" - x = xp.asarray(x, dtype=None) - dtype = jnp.promote_types(x.dtype, canonicalize_dtype(float)) - return xp.asarray(x, dtype=dtype) - - -@converter_float_array.register -def _converter_float_quantity(x: Quantity, /) -> Float[Array, "*shape"]: - """Convert to a batched vector.""" - return converter_float_array(x.to_units_value(u.dimensionless_unscaled)) - - ############################################################################## # Utils diff --git a/tests/unit/utils/test_collections.py b/tests/unit/utils/test_collections.py index 427bb8e5..1a40a2f7 100644 --- a/tests/unit/utils/test_collections.py +++ b/tests/unit/utils/test_collections.py @@ -86,6 +86,10 @@ def test_or(self, d: ImmutableDict[str, Any]) -> None: assert d | OrderedDict([("c", 3)]) == ImmutableDict(a=1, b=2, c=3) assert d | MappingProxyType({"c": 3}) == ImmutableDict(a=1, b=2, c=3) + # Should raise TypeError if not a mapping. + with pytest.raises(TypeError, match="unsupported operand type"): + _ = d | 1 + def test_ror(self, d: ImmutableDict[str, Any]) -> None: """Test `__ror__`.""" # Reverse order