From a1b421e66593e8ce406d0fb21ef009ecf6fbca54 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 8 Jul 2024 10:47:46 +0100 Subject: [PATCH] fix `arange` and others --- brainunit/_unit_test.py | 20 +++--- brainunit/math/_fun_array_creation.py | 94 ++++++--------------------- 2 files changed, 30 insertions(+), 84 deletions(-) diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 795b8e3..c7401a4 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -608,16 +608,18 @@ def test_addition_subtraction(): with pytest.raises(DimensionMismatchError): np.array([5], dtype=np.float64) - q + # Check that operations with 0 work - assert_quantity(q + 0, q.value, volt) - assert_quantity(0 + q, q.value, volt) - assert_quantity(q - 0, q.value, volt) - # Doesn't support 0 - Quantity - # assert_quantity(0 - q, -q.value, volt) - assert_quantity(q + np.float64(0), q.value, volt) - assert_quantity(np.float64(0) + q, q.value, volt) - assert_quantity(q - np.float64(0), q.value, volt) - # assert_quantity(np.float64(0) - q, -q.value, volt) + with pytest.raises(DimensionMismatchError): + assert_quantity(q + 0, q.value, volt) + assert_quantity(0 + q, q.value, volt) + assert_quantity(q - 0, q.value, volt) + # Doesn't support 0 - Quantity + # assert_quantity(0 - q, -q.value, volt) + assert_quantity(q + np.float64(0), q.value, volt) + assert_quantity(np.float64(0) + q, q.value, volt) + assert_quantity(q - np.float64(0), q.value, volt) + # assert_quantity(np.float64(0) - q, -q.value, volt) # # using unsupported objects should fail # with pytest.raises(TypeError): diff --git a/brainunit/math/_fun_array_creation.py b/brainunit/math/_fun_array_creation.py index 685b8c7..74be6e9 100644 --- a/brainunit/math/_fun_array_creation.py +++ b/brainunit/math/_fun_array_creation.py @@ -26,7 +26,7 @@ Quantity, Unit, fail_for_dimension_mismatch, - is_unitless, DimensionMismatchError, ) + DimensionMismatchError, ) from .._misc import set_module_as Shape = Union[int, Sequence[int]] @@ -649,81 +649,25 @@ def arange( out : quantity or array Array of evenly spaced values. """ - - arg_len = len([x for x in [start, stop, step] if x is not None]) - - if arg_len == 1: - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - stop = start - start = 0 - elif arg_len == 2: - if start is not None and stop is None: - stop = start - start = 0 - - elif arg_len > 3: - raise TypeError("Need between 1 and 3 non-keyword arguments") - - # default values - if start is None: - start = 0 - if step is None: - step = 1 - - if stop is None: - raise TypeError("Missing stop argument.") - if stop is not None and not is_unitless(stop): - start = Quantity(start, dim=stop.dim) - - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - fail_for_dimension_mismatch( - stop, - step, - error_message="Stop value {stop} and step value {step} have to have the same units.", - stop=stop, - step=step, - ) - - unit = getattr(stop, "dim", DIMENSIONLESS) - - if start == 0: - return Quantity( - jnp.arange( - start=start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - dtype=dtype, - ), - dim=unit, - ) if unit != DIMENSIONLESS else jnp.arange( - start=start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - dtype=dtype, + non_none_data = [d for d in (start, stop, step) if d is not None] + # checking the dimension of the data + assert len(non_none_data) > 0, 'At least one of start, stop, or step must be provided.' + d1 = non_none_data[0] + for d2 in non_none_data[1:]: + fail_for_dimension_mismatch( + d1, d2, + error_message="Start stop, and step value have to " + "have the same units. Got: {d1} {d2}", + d1=d1, d2=d2 ) - else: - return Quantity( - jnp.arange( - start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - dtype=dtype, - ), - dim=unit, - ) if unit != DIMENSIONLESS else jnp.arange( - start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - dtype=dtype, - ) - + dim = d1.dim if isinstance(d1, Quantity) else None + # convert to array + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + step = step.value if isinstance(step, Quantity) else step + # compute + r = jnp.arange(start, stop, step, dtype=dtype) + return Quantity(r, dim=dim) if dim is not None else r @set_module_as('brainunit.math')