Skip to content

Commit

Permalink
Fix test bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 27, 2024
1 parent e61fa4c commit c39110d
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 40 deletions.
5 changes: 4 additions & 1 deletion brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,10 @@ def get_dim(obj) -> Dimension:
The physical dimensions of the `obj`.
"""
try:
return obj.dim
try:
return obj.dim.dim
except:
return obj.dim
except AttributeError:
# The following is not very pretty, but it will avoid the costly
# isinstance check for the common types
Expand Down
5 changes: 4 additions & 1 deletion brainunit/math/_numpy_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',
'percentile', 'nanpercentile', 'quantile', 'nanquantile',
'corrcoef', 'correlate', 'cov',

# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',
'corrcoef', 'correlate', 'cov',


# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
Expand Down
14 changes: 5 additions & 9 deletions brainunit/math/_numpy_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def compress(
axis: Optional[int] = None,
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = None,
fill_value: Optional[jax.typing.ArrayLike] = 0,
) -> Union[Array, Quantity]:
"""
Return selected slices of a quantity or an array along given axis.
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def compress(
"""
assert not isinstance(condition, Quantity), f'condition must be an array_like. But got {condition}'
if isinstance(a, Quantity):
if fill_value is not None:
if fill_value != 0:
fail_for_dimension_mismatch(fill_value, a)
fill_value = fill_value.value
else:
Expand All @@ -1113,7 +1113,7 @@ def extract(
arr: Union[Array, Quantity],
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike | Quantity] = None,
fill_value: Optional[jax.typing.ArrayLike | Quantity] = 0,
) -> Array | Quantity:
"""
Return the elements of an array that satisfy some condition.
Expand All @@ -1137,7 +1137,7 @@ def extract(
"""
assert not isinstance(condition, Quantity), f'condition must be an array_like. But got {condition}'
if isinstance(arr, Quantity):
if fill_value is not None:
if fill_value != 0:
fail_for_dimension_mismatch(fill_value, arr)
fill_value = fill_value.value
else:
Expand Down Expand Up @@ -1196,7 +1196,6 @@ def argsort(
def argmax(
a: Union[Array, Quantity],
axis: Optional[int] = None,
keepdims: Optional[bool] = None
) -> Array:
"""
Returns indices of the max value along an axis.
Expand All @@ -1207,16 +1206,13 @@ def argmax(
Input data.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Returns
-------
res : ndarray
Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed.
"""
return _fun_remove_unit_unary(jnp.argmax, a, axis=axis, keepdim=keepdims)
return _fun_remove_unit_unary(jnp.argmax, a, axis=axis)


@set_module_as('brainunit.math')
Expand Down
3 changes: 1 addition & 2 deletions brainunit/math/_numpy_change_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def floor_divide(
out = floor(`x`/`y`)
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_change_unit_binary(jnp.floor_divide, lambda ux, uy: ux // uy, x, y)
return _fun_change_unit_binary(jnp.floor_divide, lambda ux, uy: ux / uy, x, y)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -910,4 +910,3 @@ def float_power(
return _return_check_unitless(Quantity(jnp.float_power(x, y), dim=x ** y))
else:
return jnp.float_power(x, y)

2 changes: 1 addition & 1 deletion brainunit/math/_numpy_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,4 @@ def select(
"""
for cond in condlist:
assert not isinstance(cond, Quantity), "condlist should not contain Quantity."
return _fun_keep_unit_sequence(functools.partial(jnp.select, condlist), choicelist, default)
return _fun_keep_unit_sequence(functools.partial(jnp.select, condlist), choicelist, default=default)
33 changes: 33 additions & 0 deletions brainunit/math/_numpy_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# math funcs keep unit (binary)
'fmod', 'mod', 'copysign', 'remainder',
'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd',
'remainder',

# math funcs keep unit (n-ary)
'interp', 'clip', 'histogram',
Expand Down Expand Up @@ -1613,3 +1614,35 @@ def nextafter(
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_match_unit_binary(jnp.nextafter, x, y, *args, **kwargs)

@set_module_as('brainunit.math')
def remainder(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
) -> Union[Quantity, jax.Array]:
"""
Returns the element-wise remainder of division.
Computes the remainder complementary to the `floor_divide` function. It is
equivalent to the Python modulus operator``x1 % x2`` and has the same sign
as the divisor `x2`. The MATLAB function equivalent to ``np.remainder``
is ``mod``.
Parameters
----------
x : array_like, Quantity
Dividend array.
y : array_like, Quantity
Divisor array.
If ``x1.shape != x2.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
out : ndarray, Quantity
The element-wise remainder of the quotient ``floor_divide(x1, x2)``.
This is a scalar if both `x1` and `x2` are scalars.
This is a Quantity if division of `x1` by `x2` is not dimensionless.
"""
return _fun_keep_unit_binary(jnp.remainder, x, y)
2 changes: 1 addition & 1 deletion brainunit/math/_numpy_remove_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
from jax import Array

from .._base import Quantity, DIMENSIONLESS, fail_for_dimension_mismatch
from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS
from .._misc import set_module_as

__all__ = [
Expand Down
51 changes: 26 additions & 25 deletions brainunit/math/_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def test_round(self):
self.assertTrue(jnp.all(result == jnp.round(array)))

q = bu.Quantity([1.123, 2.567, 3.891], bu.second)
result_q = bu.math.round(q)
result_q = bu.math.round(q, unit_to_scale=bu.second)
expected_q = jnp.round(jnp.array([1.123, 2.567, 3.891])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

Expand All @@ -383,7 +383,7 @@ def test_rint(self):
self.assertTrue(jnp.all(result == jnp.rint(array)))

q = bu.Quantity([1.5, 2.3, 3.8], bu.second)
result_q = bu.math.rint(q)
result_q = bu.math.rint(q, unit_to_scale=bu.second)
expected_q = jnp.rint(jnp.array([1.5, 2.3, 3.8])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

Expand All @@ -393,7 +393,7 @@ def test_floor(self):
self.assertTrue(jnp.all(result == jnp.floor(array)))

q = bu.Quantity([1.5, 2.3, 3.8], bu.second)
result_q = bu.math.floor(q)
result_q = bu.math.floor(q, unit_to_scale=bu.second)
expected_q = jnp.floor(jnp.array([1.5, 2.3, 3.8]))
assert_quantity(result_q, expected_q, bu.second)

Expand Down Expand Up @@ -713,6 +713,16 @@ def test_gcd(self):
expected_q = jnp.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

def test_remainder(self):
result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3]))))

q1 = [5, 7] * bu.second
q2 = [2, 3] * bu.second
result_q = bu.math.remainder(q1, q2)
expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)


class TestMathFuncsKeepUnitUnary2(unittest.TestCase):

Expand Down Expand Up @@ -946,15 +956,6 @@ def test_divmod(self):
expected = jnp.divmod(jnp.array([5, 6]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1]))

def test_remainder(self):
result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3]))
self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3]))))

q1 = [5, 7] * (bu.second ** 2)
q2 = [2, 3] * bu.second
result_q = bu.math.remainder(q1, q2)
expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second
assert_quantity(result_q, expected_q.value, bu.second)

def test_convolve(self):
result = bu.math.convolve(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
Expand Down Expand Up @@ -1625,7 +1626,7 @@ def test_append(self):
self.assertTrue(jnp.all(result == jnp.append(array, 3)))

q = [0, 1, 2] * bu.second
result_q = bu.math.append(q, 3)
result_q = bu.math.append(q, 3 * bu.second)
expected_q = jnp.append(jnp.array([0, 1, 2]), 3)
assert_quantity(result_q, expected_q, bu.second)

Expand Down Expand Up @@ -1901,7 +1902,7 @@ def test_searchsorted(self):
self.assertTrue(result == jnp.searchsorted(array, 2))

q = [0, 1, 2] * bu.second
result_q = bu.math.searchsorted(q, 2)
result_q = bu.math.searchsorted(q, 2 * bu.second)
expected_q = jnp.searchsorted(jnp.array([0, 1, 2]), 2)
assert result_q == expected_q

Expand Down Expand Up @@ -1932,15 +1933,15 @@ def test_bitwise_not(self):
result = bu.math.bitwise_not(jnp.array([0b1100]))
self.assertTrue(jnp.all(result == jnp.bitwise_not(jnp.array([0b1100]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.bitwise_not(q)

def test_invert(self):
result = bu.math.invert(jnp.array([0b1100]))
self.assertTrue(jnp.all(result == jnp.invert(jnp.array([0b1100]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.invert(q)

Expand All @@ -1951,7 +1952,7 @@ def test_bitwise_and(self):
result = bu.math.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_and(q1, q2)
Expand All @@ -1960,7 +1961,7 @@ def test_bitwise_or(self):
result = bu.math.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_or(q1, q2)
Expand All @@ -1969,7 +1970,7 @@ def test_bitwise_xor(self):
result = bu.math.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010]))
self.assertTrue(jnp.all(result == jnp.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q1 = [0b1100] * bu.second
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_xor(q1, q2)
Expand All @@ -1978,15 +1979,15 @@ def test_left_shift(self):
result = bu.math.left_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.left_shift(q, 2)

def test_right_shift(self):
result = bu.math.right_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [0b1100] * bu.second
result_q = bu.math.right_shift(q, 2)

Expand All @@ -1996,23 +1997,23 @@ def test_all(self):
result = bu.math.all(jnp.array([True, True, True]))
self.assertTrue(result == jnp.all(jnp.array([True, True, True])))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [True, True, True] * bu.second
result_q = bu.math.all(q)

def test_any(self):
result = bu.math.any(jnp.array([False, True, False]))
self.assertTrue(result == jnp.any(jnp.array([False, True, False])))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [False, True, False] * bu.second
result_q = bu.math.any(q)

def test_logical_not(self):
result = bu.math.logical_not(jnp.array([True, False]))
self.assertTrue(jnp.all(result == jnp.logical_not(jnp.array([True, False]))))

with pytest.raises(ValueError):
with pytest.raises(AssertionError):
q = [True, False] * bu.second
result_q = bu.math.logical_not(q)

Expand Down Expand Up @@ -2153,7 +2154,7 @@ def test_where(self):
self.assertTrue(jnp.all(result == jnp.where(array > 2, array, 0)))

q = [1, 2, 3, 4, 5] * bu.second
result_q = bu.math.where(q > 2 * bu.second, q, 0)
result_q = bu.math.where(q > 2 * bu.second, q.to_value(bu.second), 0)
expected_q = jnp.where(jnp.array([1, 2, 3, 4, 5]) > 2, jnp.array([1, 2, 3, 4, 5]), 0)
assert_quantity(result_q, expected_q, bu.second)

Expand Down

0 comments on commit c39110d

Please sign in to comment.