Skip to content

Commit

Permalink
fix arange and others
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jul 8, 2024
1 parent 51491f8 commit a1b421e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 84 deletions.
20 changes: 11 additions & 9 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,16 +608,18 @@ def test_addition_subtraction():
with pytest.raises(DimensionMismatchError):
np.array([5], dtype=np.float64) - q


# Check that operations with 0 work
assert_quantity(q + 0, q.value, volt)
assert_quantity(0 + q, q.value, volt)
assert_quantity(q - 0, q.value, volt)
# Doesn't support 0 - Quantity
# assert_quantity(0 - q, -q.value, volt)
assert_quantity(q + np.float64(0), q.value, volt)
assert_quantity(np.float64(0) + q, q.value, volt)
assert_quantity(q - np.float64(0), q.value, volt)
# assert_quantity(np.float64(0) - q, -q.value, volt)
with pytest.raises(DimensionMismatchError):
assert_quantity(q + 0, q.value, volt)
assert_quantity(0 + q, q.value, volt)
assert_quantity(q - 0, q.value, volt)
# Doesn't support 0 - Quantity
# assert_quantity(0 - q, -q.value, volt)
assert_quantity(q + np.float64(0), q.value, volt)
assert_quantity(np.float64(0) + q, q.value, volt)
assert_quantity(q - np.float64(0), q.value, volt)
# assert_quantity(np.float64(0) - q, -q.value, volt)

# # using unsupported objects should fail
# with pytest.raises(TypeError):
Expand Down
94 changes: 19 additions & 75 deletions brainunit/math/_fun_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless, DimensionMismatchError, )
DimensionMismatchError, )
from .._misc import set_module_as

Shape = Union[int, Sequence[int]]
Expand Down Expand Up @@ -649,81 +649,25 @@ def arange(
out : quantity or array
Array of evenly spaced values.
"""

arg_len = len([x for x in [start, stop, step] if x is not None])

if arg_len == 1:
if stop is not None:
raise TypeError("Duplicate definition of 'stop'")
stop = start
start = 0
elif arg_len == 2:
if start is not None and stop is None:
stop = start
start = 0

elif arg_len > 3:
raise TypeError("Need between 1 and 3 non-keyword arguments")

# default values
if start is None:
start = 0
if step is None:
step = 1

if stop is None:
raise TypeError("Missing stop argument.")
if stop is not None and not is_unitless(stop):
start = Quantity(start, dim=stop.dim)

fail_for_dimension_mismatch(
start,
stop,
error_message="Start value {start} and stop value {stop} have to have the same units.",
start=start,
stop=stop,
)
fail_for_dimension_mismatch(
stop,
step,
error_message="Stop value {stop} and step value {step} have to have the same units.",
stop=stop,
step=step,
)

unit = getattr(stop, "dim", DIMENSIONLESS)

if start == 0:
return Quantity(
jnp.arange(
start=start.value if isinstance(start, Quantity) else jnp.asarray(start),
stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop),
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
dtype=dtype,
),
dim=unit,
) if unit != DIMENSIONLESS else jnp.arange(
start=start.value if isinstance(start, Quantity) else jnp.asarray(start),
stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop),
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
dtype=dtype,
non_none_data = [d for d in (start, stop, step) if d is not None]
# checking the dimension of the data
assert len(non_none_data) > 0, 'At least one of start, stop, or step must be provided.'
d1 = non_none_data[0]
for d2 in non_none_data[1:]:
fail_for_dimension_mismatch(
d1, d2,
error_message="Start stop, and step value have to "
"have the same units. Got: {d1} {d2}",
d1=d1, d2=d2
)
else:
return Quantity(
jnp.arange(
start.value if isinstance(start, Quantity) else jnp.asarray(start),
stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop),
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
dtype=dtype,
),
dim=unit,
) if unit != DIMENSIONLESS else jnp.arange(
start.value if isinstance(start, Quantity) else jnp.asarray(start),
stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop),
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
dtype=dtype,
)

dim = d1.dim if isinstance(d1, Quantity) else None
# convert to array
start = start.value if isinstance(start, Quantity) else start
stop = stop.value if isinstance(stop, Quantity) else stop
step = step.value if isinstance(step, Quantity) else step
# compute
r = jnp.arange(start, stop, step, dtype=dtype)
return Quantity(r, dim=dim) if dim is not None else r


@set_module_as('brainunit.math')
Expand Down

0 comments on commit a1b421e

Please sign in to comment.