From c39110da715e216f3a1030e9812a663f0329cdb5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 27 Jun 2024 10:59:06 +0800 Subject: [PATCH] Fix test bugs --- brainunit/_base.py | 5 +- brainunit/math/_numpy_accept_unitless.py | 5 +- brainunit/math/_numpy_array_manipulation.py | 14 ++---- brainunit/math/_numpy_change_unit.py | 3 +- brainunit/math/_numpy_indexing.py | 2 +- brainunit/math/_numpy_keep_unit.py | 33 +++++++++++++ brainunit/math/_numpy_remove_unit.py | 2 +- brainunit/math/_numpy_test.py | 51 +++++++++++---------- 8 files changed, 75 insertions(+), 40 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 8622c74..9d4cb39 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -551,7 +551,10 @@ def get_dim(obj) -> Dimension: The physical dimensions of the `obj`. """ try: - return obj.dim + try: + return obj.dim.dim + except: + return obj.dim except AttributeError: # The following is not very pretty, but it will avoid the costly # isinstance check for the common types diff --git a/brainunit/math/_numpy_accept_unitless.py b/brainunit/math/_numpy_accept_unitless.py index cac78b1..da17489 100644 --- a/brainunit/math/_numpy_accept_unitless.py +++ b/brainunit/math/_numpy_accept_unitless.py @@ -30,9 +30,12 @@ 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + 'corrcoef', 'correlate', 'cov', + + # math funcs only accept unitless (unary) can return Quantity 'round', 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp', - 'corrcoef', 'correlate', 'cov', + # math funcs only accept unitless (binary) 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', diff --git a/brainunit/math/_numpy_array_manipulation.py b/brainunit/math/_numpy_array_manipulation.py index a8bfdf5..9efa6cd 100644 --- a/brainunit/math/_numpy_array_manipulation.py +++ b/brainunit/math/_numpy_array_manipulation.py @@ -1068,7 +1068,7 @@ def compress( axis: Optional[int] = None, *, size: Optional[int] = None, - fill_value: Optional[jax.typing.ArrayLike] = None, + fill_value: Optional[jax.typing.ArrayLike] = 0, ) -> Union[Array, Quantity]: """ Return selected slices of a quantity or an array along given axis. @@ -1096,7 +1096,7 @@ def compress( """ assert not isinstance(condition, Quantity), f'condition must be an array_like. But got {condition}' if isinstance(a, Quantity): - if fill_value is not None: + if fill_value != 0: fail_for_dimension_mismatch(fill_value, a) fill_value = fill_value.value else: @@ -1113,7 +1113,7 @@ def extract( arr: Union[Array, Quantity], *, size: Optional[int] = None, - fill_value: Optional[jax.typing.ArrayLike | Quantity] = None, + fill_value: Optional[jax.typing.ArrayLike | Quantity] = 0, ) -> Array | Quantity: """ Return the elements of an array that satisfy some condition. @@ -1137,7 +1137,7 @@ def extract( """ assert not isinstance(condition, Quantity), f'condition must be an array_like. But got {condition}' if isinstance(arr, Quantity): - if fill_value is not None: + if fill_value != 0: fail_for_dimension_mismatch(fill_value, arr) fill_value = fill_value.value else: @@ -1196,7 +1196,6 @@ def argsort( def argmax( a: Union[Array, Quantity], axis: Optional[int] = None, - keepdims: Optional[bool] = None ) -> Array: """ Returns indices of the max value along an axis. @@ -1207,16 +1206,13 @@ def argmax( Input data. axis : int, optional By default, the index is into the flattened array, otherwise along the specified axis. - keepdims : bool, optional - If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this - option, the result will broadcast correctly against the input array. Returns ------- res : ndarray Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed. """ - return _fun_remove_unit_unary(jnp.argmax, a, axis=axis, keepdim=keepdims) + return _fun_remove_unit_unary(jnp.argmax, a, axis=axis) @set_module_as('brainunit.math') diff --git a/brainunit/math/_numpy_change_unit.py b/brainunit/math/_numpy_change_unit.py index de4d75e..3f46d44 100644 --- a/brainunit/math/_numpy_change_unit.py +++ b/brainunit/math/_numpy_change_unit.py @@ -860,7 +860,7 @@ def floor_divide( out = floor(`x`/`y`) This is a scalar if both `x` and `y` are scalars. """ - return _fun_change_unit_binary(jnp.floor_divide, lambda ux, uy: ux // uy, x, y) + return _fun_change_unit_binary(jnp.floor_divide, lambda ux, uy: ux / uy, x, y) @set_module_as('brainunit.math') @@ -910,4 +910,3 @@ def float_power( return _return_check_unitless(Quantity(jnp.float_power(x, y), dim=x ** y)) else: return jnp.float_power(x, y) - diff --git a/brainunit/math/_numpy_indexing.py b/brainunit/math/_numpy_indexing.py index 2db2413..16b1c3f 100644 --- a/brainunit/math/_numpy_indexing.py +++ b/brainunit/math/_numpy_indexing.py @@ -228,4 +228,4 @@ def select( """ for cond in condlist: assert not isinstance(cond, Quantity), "condlist should not contain Quantity." - return _fun_keep_unit_sequence(functools.partial(jnp.select, condlist), choicelist, default) + return _fun_keep_unit_sequence(functools.partial(jnp.select, condlist), choicelist, default=default) diff --git a/brainunit/math/_numpy_keep_unit.py b/brainunit/math/_numpy_keep_unit.py index c763116..211b037 100644 --- a/brainunit/math/_numpy_keep_unit.py +++ b/brainunit/math/_numpy_keep_unit.py @@ -33,6 +33,7 @@ # math funcs keep unit (binary) 'fmod', 'mod', 'copysign', 'remainder', 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', + 'remainder', # math funcs keep unit (n-ary) 'interp', 'clip', 'histogram', @@ -1613,3 +1614,35 @@ def nextafter( This is a scalar if both `x` and `y` are scalars. """ return _fun_match_unit_binary(jnp.nextafter, x, y, *args, **kwargs) + +@set_module_as('brainunit.math') +def remainder( + x: Union[Quantity, jax.typing.ArrayLike], + y: Union[Quantity, jax.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + """ + Returns the element-wise remainder of division. + + Computes the remainder complementary to the `floor_divide` function. It is + equivalent to the Python modulus operator``x1 % x2`` and has the same sign + as the divisor `x2`. The MATLAB function equivalent to ``np.remainder`` + is ``mod``. + + Parameters + ---------- + x : array_like, Quantity + Dividend array. + y : array_like, Quantity + Divisor array. + If ``x1.shape != x2.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray, Quantity + The element-wise remainder of the quotient ``floor_divide(x1, x2)``. + This is a scalar if both `x1` and `x2` are scalars. + + This is a Quantity if division of `x1` by `x2` is not dimensionless. + """ + return _fun_keep_unit_binary(jnp.remainder, x, y) diff --git a/brainunit/math/_numpy_remove_unit.py b/brainunit/math/_numpy_remove_unit.py index f48a84c..92e90d1 100644 --- a/brainunit/math/_numpy_remove_unit.py +++ b/brainunit/math/_numpy_remove_unit.py @@ -20,7 +20,7 @@ import jax.numpy as jnp from jax import Array -from .._base import Quantity, DIMENSIONLESS, fail_for_dimension_mismatch +from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS from .._misc import set_module_as __all__ = [ diff --git a/brainunit/math/_numpy_test.py b/brainunit/math/_numpy_test.py index e582038..4765b67 100644 --- a/brainunit/math/_numpy_test.py +++ b/brainunit/math/_numpy_test.py @@ -373,7 +373,7 @@ def test_round(self): self.assertTrue(jnp.all(result == jnp.round(array))) q = bu.Quantity([1.123, 2.567, 3.891], bu.second) - result_q = bu.math.round(q) + result_q = bu.math.round(q, unit_to_scale=bu.second) expected_q = jnp.round(jnp.array([1.123, 2.567, 3.891])) * bu.second assert_quantity(result_q, expected_q.value, bu.second) @@ -383,7 +383,7 @@ def test_rint(self): self.assertTrue(jnp.all(result == jnp.rint(array))) q = bu.Quantity([1.5, 2.3, 3.8], bu.second) - result_q = bu.math.rint(q) + result_q = bu.math.rint(q, unit_to_scale=bu.second) expected_q = jnp.rint(jnp.array([1.5, 2.3, 3.8])) * bu.second assert_quantity(result_q, expected_q.value, bu.second) @@ -393,7 +393,7 @@ def test_floor(self): self.assertTrue(jnp.all(result == jnp.floor(array))) q = bu.Quantity([1.5, 2.3, 3.8], bu.second) - result_q = bu.math.floor(q) + result_q = bu.math.floor(q, unit_to_scale=bu.second) expected_q = jnp.floor(jnp.array([1.5, 2.3, 3.8])) assert_quantity(result_q, expected_q, bu.second) @@ -713,6 +713,16 @@ def test_gcd(self): expected_q = jnp.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) * bu.second assert_quantity(result_q, expected_q.value, bu.second) + def test_remainder(self): + result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])))) + + q1 = [5, 7] * bu.second + q2 = [2, 3] * bu.second + result_q = bu.math.remainder(q1, q2) + expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second + assert_quantity(result_q, expected_q.value, bu.second) + class TestMathFuncsKeepUnitUnary2(unittest.TestCase): @@ -946,15 +956,6 @@ def test_divmod(self): expected = jnp.divmod(jnp.array([5, 6]), jnp.array([2, 3])) self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1])) - def test_remainder(self): - result = bu.math.remainder(jnp.array([5, 7]), jnp.array([2, 3])) - self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])))) - - q1 = [5, 7] * (bu.second ** 2) - q2 = [2, 3] * bu.second - result_q = bu.math.remainder(q1, q2) - expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * bu.second - assert_quantity(result_q, expected_q.value, bu.second) def test_convolve(self): result = bu.math.convolve(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) @@ -1625,7 +1626,7 @@ def test_append(self): self.assertTrue(jnp.all(result == jnp.append(array, 3))) q = [0, 1, 2] * bu.second - result_q = bu.math.append(q, 3) + result_q = bu.math.append(q, 3 * bu.second) expected_q = jnp.append(jnp.array([0, 1, 2]), 3) assert_quantity(result_q, expected_q, bu.second) @@ -1901,7 +1902,7 @@ def test_searchsorted(self): self.assertTrue(result == jnp.searchsorted(array, 2)) q = [0, 1, 2] * bu.second - result_q = bu.math.searchsorted(q, 2) + result_q = bu.math.searchsorted(q, 2 * bu.second) expected_q = jnp.searchsorted(jnp.array([0, 1, 2]), 2) assert result_q == expected_q @@ -1932,7 +1933,7 @@ def test_bitwise_not(self): result = bu.math.bitwise_not(jnp.array([0b1100])) self.assertTrue(jnp.all(result == jnp.bitwise_not(jnp.array([0b1100])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [0b1100] * bu.second result_q = bu.math.bitwise_not(q) @@ -1940,7 +1941,7 @@ def test_invert(self): result = bu.math.invert(jnp.array([0b1100])) self.assertTrue(jnp.all(result == jnp.invert(jnp.array([0b1100])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [0b1100] * bu.second result_q = bu.math.invert(q) @@ -1951,7 +1952,7 @@ def test_bitwise_and(self): result = bu.math.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010])) self.assertTrue(jnp.all(result == jnp.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q1 = [0b1100] * bu.second q2 = [0b1010] * bu.second result_q = bu.math.bitwise_and(q1, q2) @@ -1960,7 +1961,7 @@ def test_bitwise_or(self): result = bu.math.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010])) self.assertTrue(jnp.all(result == jnp.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q1 = [0b1100] * bu.second q2 = [0b1010] * bu.second result_q = bu.math.bitwise_or(q1, q2) @@ -1969,7 +1970,7 @@ def test_bitwise_xor(self): result = bu.math.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010])) self.assertTrue(jnp.all(result == jnp.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q1 = [0b1100] * bu.second q2 = [0b1010] * bu.second result_q = bu.math.bitwise_xor(q1, q2) @@ -1978,7 +1979,7 @@ def test_left_shift(self): result = bu.math.left_shift(jnp.array([0b1100]), 2) self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [0b1100] * bu.second result_q = bu.math.left_shift(q, 2) @@ -1986,7 +1987,7 @@ def test_right_shift(self): result = bu.math.right_shift(jnp.array([0b1100]), 2) self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [0b1100] * bu.second result_q = bu.math.right_shift(q, 2) @@ -1996,7 +1997,7 @@ def test_all(self): result = bu.math.all(jnp.array([True, True, True])) self.assertTrue(result == jnp.all(jnp.array([True, True, True]))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [True, True, True] * bu.second result_q = bu.math.all(q) @@ -2004,7 +2005,7 @@ def test_any(self): result = bu.math.any(jnp.array([False, True, False])) self.assertTrue(result == jnp.any(jnp.array([False, True, False]))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [False, True, False] * bu.second result_q = bu.math.any(q) @@ -2012,7 +2013,7 @@ def test_logical_not(self): result = bu.math.logical_not(jnp.array([True, False])) self.assertTrue(jnp.all(result == jnp.logical_not(jnp.array([True, False])))) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): q = [True, False] * bu.second result_q = bu.math.logical_not(q) @@ -2153,7 +2154,7 @@ def test_where(self): self.assertTrue(jnp.all(result == jnp.where(array > 2, array, 0))) q = [1, 2, 3, 4, 5] * bu.second - result_q = bu.math.where(q > 2 * bu.second, q, 0) + result_q = bu.math.where(q > 2 * bu.second, q.to_value(bu.second), 0) expected_q = jnp.where(jnp.array([1, 2, 3, 4, 5]) > 2, jnp.array([1, 2, 3, 4, 5]), 0) assert_quantity(result_q, expected_q, bu.second)