From b59bbb3fa7465f6e01dcef3c570bc00e026673a6 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 26 Jun 2024 22:58:12 +0800 Subject: [PATCH] Fix the logic of `numpy_accept_unitless.py` and Add their tests --- brainunit/math/_numpy_accept_unitless.py | 106 +++++++++++--------- brainunit/math/_numpy_test.py | 122 +++++++++++++++++++++++ 2 files changed, 180 insertions(+), 48 deletions(-) diff --git a/brainunit/math/_numpy_accept_unitless.py b/brainunit/math/_numpy_accept_unitless.py index 68be1a1..5fdeadb 100644 --- a/brainunit/math/_numpy_accept_unitless.py +++ b/brainunit/math/_numpy_accept_unitless.py @@ -69,7 +69,7 @@ def _func_accept_unitless_unary( value=x, unit_to_scale=unit_to_scale ) - return Quantity(func(x / unit_to_scale, *args, **kwargs), unit=unit_to_scale) + return func(x.to_value(unit_to_scale), *args, **kwargs) else: assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.' return func(x, *args, **kwargs) @@ -889,23 +889,40 @@ def frexp( # ---------------------------------------- -def _fun_accept_unitless_binary(func, x, y, *args, **kwargs): - x_value = x.value if isinstance(x, Quantity) else x - y_value = y.value if isinstance(y, Quantity) else y - if isinstance(x, Quantity) or isinstance(y, Quantity): - fail_for_dimension_mismatch( - x, - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - fail_for_dimension_mismatch( - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=y, - ) - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) +def _fun_accept_unitless_binary( + func: Callable, + x: jax.typing.ArrayLike | Quantity, + y: jax.typing.ArrayLike | Quantity, + *args, + unit_to_scale: Optional[Unit] = None, + **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + if unit_to_scale is None: + assert x.is_unitless, (f'Input should be unitless for the function "{func}" ' + f'when scaling "unit_to_scale" is not provided.') + assert y.is_unitless, (f'Input should be unitless for the function "{func}" ' + f'when scaling "unit_to_scale" is not provided.') + x = x.value + y = y.value + return func(x, y, *args, **kwargs) + else: + fail_for_dimension_mismatch( + x, + unit_to_scale, + error_message="Unit mismatch: {value} != {unit_to_scale}", + value=x, + unit_to_scale=unit_to_scale + ) + fail_for_dimension_mismatch( + y, + unit_to_scale, + error_message="Unit mismatch: {value} != {unit_to_scale}", + value=y, + unit_to_scale=unit_to_scale + ) + return func(x.to_value(unit_to_scale), y.to_value(unit_to_scale), *args, **kwargs) else: + assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" and "y" are not Quantities.' return func(x, y, *args, **kwargs) @@ -913,6 +930,7 @@ def _fun_accept_unitless_binary(func, x, y, *args, **kwargs): def hypot( x: Union[Array, Quantity], y: Union[Array, Quantity], + unit_to_scale: Optional[Unit] = None, ) -> Array | Quantity: """ Given the “legs” of a right triangle, return its hypotenuse. @@ -929,13 +947,14 @@ def hypot( out : jax.Array Output array. """ - return _fun_accept_unitless_binary(jnp.hypot, x, y) + return _fun_accept_unitless_binary(jnp.hypot, x, y, unit_to_scale=unit_to_scale) @set_module_as('brainunit.math') def arctan2( x: Union[Array, Quantity], y: Union[Array, Quantity], + unit_to_scale: Optional[Unit] = None, ) -> Array | Quantity: """ Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. @@ -952,13 +971,14 @@ def arctan2( out : jax.Array Output array. """ - return _fun_accept_unitless_binary(jnp.arctan2, x, y) + return _fun_accept_unitless_binary(jnp.arctan2, x, y, unit_to_scale=unit_to_scale) @set_module_as('brainunit.math') def logaddexp( x: Union[Array, Quantity], y: Union[Array, Quantity], + unit_to_scale: Optional[Unit] = None, ) -> Array | Quantity: """ Logarithm of the sum of exponentiations of the inputs. @@ -975,11 +995,15 @@ def logaddexp( out : jax.Array Output array. """ - return _fun_accept_unitless_binary(jnp.logaddexp, x, y) + return _fun_accept_unitless_binary(jnp.logaddexp, x, y, unit_to_scale=unit_to_scale) @set_module_as('brainunit.math') -def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array | Quantity: +def logaddexp2( + x: Union[Array, Quantity], + y: Union[Array, Quantity], + unit_to_scale: Optional[Unit] = None, +) -> Array | Quantity: """ Logarithm of the sum of exponentiations of the inputs in base-2. @@ -995,7 +1019,7 @@ def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array | out : jax.Array Output array. """ - return _fun_accept_unitless_binary(jnp.logaddexp2, x, y) + return _fun_accept_unitless_binary(jnp.logaddexp2, x, y, unit_to_scale=unit_to_scale) @set_module_as('brainunit.math') @@ -1003,9 +1027,9 @@ def percentile( a: Union[Array, Quantity], q: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int]]] = None, - overwrite_input: Optional[bool] = None, method: str = 'linear', keepdims: Optional[bool] = False, + unit_to_scale: Optional[Unit] = None, ) -> Array: """ Compute the q-th percentile of the data along the specified axis. @@ -1018,11 +1042,6 @@ def percentile( Input array or Quantity. q : array_like, Quantity Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. - out : array_like, Quantity, optional - Alternative output array in which to place the result. - It must have the same shape and buffer length as the expected output but the type will be cast if necessary. - overwrite_input : bool, optional - If True, then allow the input array a to be modified by intermediate calculations, to save memory. method : str, optional This parameter specifies the method to use for estimating the percentile. There are many different methods, some unique to NumPy. @@ -1055,8 +1074,8 @@ def percentile( Output array. """ return _fun_accept_unitless_binary( - jnp.percentile, a, q, axis=axis, overwrite_input=overwrite_input, - method=method, keepdims=keepdims + jnp.percentile, a, q, axis=axis, + method=method, keepdims=keepdims, unit_to_scale=unit_to_scale ) @@ -1065,9 +1084,9 @@ def nanpercentile( a: Union[Array, Quantity], q: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int]]] = None, - overwrite_input: Optional[bool] = None, method: str = 'linear', keepdims: Optional[bool] = False, + unit_to_scale: Optional[Unit] = None, ) -> Array: """ Compute the q-th percentile of the data along the specified axis, while ignoring nan values. @@ -1080,11 +1099,6 @@ def nanpercentile( Input array or Quantity. q : array_like, Quantity Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. - out : array_like, Quantity, optional - Alternative output array in which to place the result. - It must have the same shape and buffer length as the expected output but the type will be cast if necessary. - overwrite_input : bool, optional - If True, then allow the input array a to be modified by intermediate calculations, to save memory. method : str, optional This parameter specifies the method to use for estimating the percentile. There are many different methods, some unique to NumPy. @@ -1118,8 +1132,8 @@ def nanpercentile( """ return _fun_accept_unitless_binary( jnp.nanpercentile, a, q, - axis=axis, ooverwrite_input=overwrite_input, - method=method, keepdims=keepdims + axis=axis, + method=method, keepdims=keepdims, unit_to_scale=unit_to_scale ) @@ -1128,9 +1142,9 @@ def quantile( a: Union[Array, Quantity], q: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int]]] = None, - overwrite_input: Optional[bool] = None, method: str = 'linear', keepdims: Optional[bool] = False, + unit_to_scale: Optional[Unit] = None, ) -> Array: """ Compute the q-th percentile of the data along the specified axis. @@ -1143,8 +1157,6 @@ def quantile( Input array or Quantity. q : array_like, Quantity Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. - overwrite_input : bool, optional - If True, then allow the input array a to be modified by intermediate calculations, to save memory. method : str, optional This parameter specifies the method to use for estimating the percentile. There are many different methods, some unique to NumPy. @@ -1178,8 +1190,8 @@ def quantile( """ return _fun_accept_unitless_binary( jnp.quantile, a, q, - axis=axis, overwrite_input=overwrite_input, - method=method, keepdims=keepdims + axis=axis, + method=method, keepdims=keepdims, unit_to_scale=unit_to_scale ) @@ -1188,9 +1200,9 @@ def nanquantile( a: Union[Array, Quantity], q: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int]]] = None, - overwrite_input: Optional[bool] = None, method: str = 'linear', keepdims: Optional[bool] = False, + unit_to_scale: Optional[Unit] = None, ) -> Array: """ Compute the q-th percentile of the data along the specified axis, while ignoring nan values. @@ -1203,8 +1215,6 @@ def nanquantile( Input array or Quantity. q : array_like, Quantity Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. - overwrite_input : bool, optional - If True, then allow the input array a to be modified by intermediate calculations, to save memory. method : str, optional This parameter specifies the method to use for estimating the percentile. There are many different methods, some unique to NumPy. @@ -1238,8 +1248,8 @@ def nanquantile( """ return _fun_accept_unitless_binary( jnp.nanquantile, a, q, - axis=axis, overwrite_input=overwrite_input, - method=method, keepdims=keepdims + axis=axis, + method=method, keepdims=keepdims, unit_to_scale=unit_to_scale ) diff --git a/brainunit/math/_numpy_test.py b/brainunit/math/_numpy_test.py index 89cbf4c..e582038 100644 --- a/brainunit/math/_numpy_test.py +++ b/brainunit/math/_numpy_test.py @@ -970,6 +970,10 @@ def test_exp(self): result = bu.math.exp(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.exp(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.exp(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.exp(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_exp2(self): result = bu.math.exp2(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.exp2(jnp.array([1.0, 2.0])))) @@ -977,6 +981,10 @@ def test_exp2(self): result = bu.math.exp2(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.exp2(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.exp2(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.exp2(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_expm1(self): result = bu.math.expm1(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.expm1(jnp.array([1.0, 2.0])))) @@ -984,6 +992,10 @@ def test_expm1(self): result = bu.math.expm1(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.expm1(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.expm1(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.expm1(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_log(self): result = bu.math.log(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.log(jnp.array([1.0, 2.0])))) @@ -991,6 +1003,10 @@ def test_log(self): result = bu.math.log(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.log(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.log(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.log(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_log10(self): result = bu.math.log10(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.log10(jnp.array([1.0, 2.0])))) @@ -998,6 +1014,10 @@ def test_log10(self): result = bu.math.log10(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.log10(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.log10(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.log10(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_log1p(self): result = bu.math.log1p(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.log1p(jnp.array([1.0, 2.0])))) @@ -1005,6 +1025,10 @@ def test_log1p(self): result = bu.math.log1p(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.log1p(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.log1p(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.log1p(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_log2(self): result = bu.math.log2(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.log2(jnp.array([1.0, 2.0])))) @@ -1012,6 +1036,10 @@ def test_log2(self): result = bu.math.log2(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.log2(jnp.array([1.0, 2.0])))) + q = [1.0, 2.0] * bu.meter + result = bu.math.log2(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.log2(jnp.array([1.0, 2.0]) / bu.dametre.value))) + def test_arccos(self): result = bu.math.arccos(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.arccos(jnp.array([0.5, 1.0])))) @@ -1019,6 +1047,10 @@ def test_arccos(self): result = bu.math.arccos(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.arccos(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.arccos(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arccos(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_arccosh(self): result = bu.math.arccosh(jnp.array([1.0, 2.0])) self.assertTrue(jnp.all(result == jnp.arccosh(jnp.array([1.0, 2.0])))) @@ -1026,6 +1058,10 @@ def test_arccosh(self): result = bu.math.arccosh(Quantity(jnp.array([1.0, 2.0]))) self.assertTrue(jnp.all(result == jnp.arccosh(jnp.array([1.0, 2.0])))) + q = [10., 20.] * bu.meter + result = bu.math.arccosh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arccosh(jnp.array([10., 20.]) / bu.dametre.value))) + def test_arcsin(self): result = bu.math.arcsin(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.arcsin(jnp.array([0.5, 1.0])))) @@ -1033,6 +1069,10 @@ def test_arcsin(self): result = bu.math.arcsin(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.arcsin(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.arcsin(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arcsin(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_arcsinh(self): result = bu.math.arcsinh(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.arcsinh(jnp.array([0.5, 1.0])))) @@ -1040,6 +1080,10 @@ def test_arcsinh(self): result = bu.math.arcsinh(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.arcsinh(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.arcsinh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arcsinh(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_arctan(self): result = bu.math.arctan(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.arctan(jnp.array([0.5, 1.0])))) @@ -1047,6 +1091,10 @@ def test_arctan(self): result = bu.math.arctan(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.arctan(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.arctan(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arctan(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_arctanh(self): result = bu.math.arctanh(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.arctanh(jnp.array([0.5, 1.0])))) @@ -1054,6 +1102,10 @@ def test_arctanh(self): result = bu.math.arctanh(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.arctanh(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.arctanh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.arctanh(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_cos(self): result = bu.math.cos(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.cos(jnp.array([0.5, 1.0])))) @@ -1061,6 +1113,10 @@ def test_cos(self): result = bu.math.cos(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.cos(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.cos(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.cos(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_cosh(self): result = bu.math.cosh(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.cosh(jnp.array([0.5, 1.0])))) @@ -1068,6 +1124,10 @@ def test_cosh(self): result = bu.math.cosh(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.cosh(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.cosh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.cosh(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_sin(self): result = bu.math.sin(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.sin(jnp.array([0.5, 1.0])))) @@ -1075,6 +1135,10 @@ def test_sin(self): result = bu.math.sin(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.sin(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.sin(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.sin(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_sinc(self): result = bu.math.sinc(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.sinc(jnp.array([0.5, 1.0])))) @@ -1082,6 +1146,10 @@ def test_sinc(self): result = bu.math.sinc(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.sinc(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.sinc(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.sinc(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_sinh(self): result = bu.math.sinh(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.sinh(jnp.array([0.5, 1.0])))) @@ -1089,6 +1157,10 @@ def test_sinh(self): result = bu.math.sinh(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.sinh(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.sinh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.sinh(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_tan(self): result = bu.math.tan(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.tan(jnp.array([0.5, 1.0])))) @@ -1096,6 +1168,10 @@ def test_tan(self): result = bu.math.tan(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.tan(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.tan(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.tan(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_tanh(self): result = bu.math.tanh(jnp.array([0.5, 1.0])) self.assertTrue(jnp.all(result == jnp.tanh(jnp.array([0.5, 1.0])))) @@ -1103,6 +1179,10 @@ def test_tanh(self): result = bu.math.tanh(Quantity(jnp.array([0.5, 1.0]))) self.assertTrue(jnp.all(result == jnp.tanh(jnp.array([0.5, 1.0])))) + q = [0.5, 1.0] * bu.meter + result = bu.math.tanh(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.tanh(jnp.array([0.5, 1.0]) / bu.dametre.value))) + def test_deg2rad(self): result = bu.math.deg2rad(jnp.array([90.0, 180.0])) self.assertTrue(jnp.all(result == jnp.deg2rad(jnp.array([90.0, 180.0])))) @@ -1110,6 +1190,10 @@ def test_deg2rad(self): result = bu.math.deg2rad(Quantity(jnp.array([90.0, 180.0]))) self.assertTrue(jnp.all(result == jnp.deg2rad(jnp.array([90.0, 180.0])))) + q = [90.0, 180.0] * bu.meter + result = bu.math.deg2rad(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.deg2rad(jnp.array([90.0, 180.0]) / bu.dametre.value))) + def test_rad2deg(self): result = bu.math.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])) self.assertTrue(jnp.all(result == jnp.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])))) @@ -1117,6 +1201,10 @@ def test_rad2deg(self): result = bu.math.rad2deg(Quantity(jnp.array([jnp.pi / 2, jnp.pi]))) self.assertTrue(jnp.all(result == jnp.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])))) + q = [jnp.pi / 2, jnp.pi] * bu.meter + result = bu.math.rad2deg(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.rad2deg(jnp.array([jnp.pi / 2, jnp.pi]) / bu.dametre.value))) + def test_degrees(self): result = bu.math.degrees(jnp.array([jnp.pi / 2, jnp.pi])) self.assertTrue(jnp.all(result == jnp.degrees(jnp.array([jnp.pi / 2, jnp.pi])))) @@ -1124,6 +1212,10 @@ def test_degrees(self): result = bu.math.degrees(Quantity(jnp.array([jnp.pi / 2, jnp.pi]))) self.assertTrue(jnp.all(result == jnp.degrees(jnp.array([jnp.pi / 2, jnp.pi])))) + q = [jnp.pi / 2, jnp.pi] * bu.meter + result = bu.math.degrees(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.degrees(jnp.array([jnp.pi / 2, jnp.pi]) / bu.dametre.value))) + def test_radians(self): result = bu.math.radians(jnp.array([90.0, 180.0])) self.assertTrue(jnp.all(result == jnp.radians(jnp.array([90.0, 180.0])))) @@ -1131,6 +1223,10 @@ def test_radians(self): result = bu.math.radians(Quantity(jnp.array([90.0, 180.0]))) self.assertTrue(jnp.all(result == jnp.radians(jnp.array([90.0, 180.0])))) + q = [90.0, 180.0] * bu.meter + result = bu.math.radians(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.radians(jnp.array([90.0, 180.0]) / bu.dametre.value))) + def test_angle(self): result = bu.math.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])) self.assertTrue(jnp.all(result == jnp.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])))) @@ -1138,26 +1234,52 @@ def test_angle(self): result = bu.math.angle(Quantity(jnp.array([1.0 + 1.0j, 1.0 - 1.0j]))) self.assertTrue(jnp.all(result == jnp.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])))) + q = [1.0 + 1.0j, 1.0 - 1.0j] * bu.meter + result = bu.math.angle(q, unit_to_scale=bu.dametre) + self.assertTrue(jnp.all(result == jnp.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j]) / bu.dametre.value))) + def test_percentile(self): array = jnp.array([1, 2, 3, 4]) result = bu.math.percentile(array, 50) self.assertTrue(result == jnp.percentile(array, 50)) + quantity = jnp.array([1, 2, 3, 4]) * bu.meter + result = bu.math.percentile(quantity, 50 * bu.meter, unit_to_scale=bu.dametre) + self.assertTrue(result == jnp.percentile(array / bu.dametre.value, 50 / bu.dametre.value)) + def test_nanpercentile(self): array = jnp.array([1, jnp.nan, 3, 4]) result = bu.math.nanpercentile(array, 50) self.assertTrue(result == jnp.nanpercentile(array, 50)) + quantity = jnp.array([1, 2, jnp.nan, 4]) * bu.meter + result = bu.math.percentile(quantity, 50 * bu.meter, unit_to_scale=bu.dametre) + if jnp.isnan(result) and jnp.isnan(jnp.percentile(array / bu.dametre.value, 50 / bu.dametre.value)): + self.assertTrue(True) + else: + self.assertTrue(result == jnp.percentile(array / bu.dametre.value, 50 / bu.dametre.value)) + def test_quantile(self): array = jnp.array([1, 2, 3, 4]) result = bu.math.quantile(array, 0.5) self.assertTrue(result == jnp.quantile(array, 0.5)) + quantity = jnp.array([1, 2, 3, 4]) * bu.meter + result = bu.math.percentile(quantity, 0.5 * bu.meter, unit_to_scale=bu.dametre) + self.assertTrue(result == jnp.percentile(array / bu.dametre.value, 0.5 / bu.dametre.value)) + def test_nanquantile(self): array = jnp.array([1, jnp.nan, 3, 4]) result = bu.math.nanquantile(array, 0.5) self.assertTrue(result == jnp.nanquantile(array, 0.5)) + quantity = jnp.array([1, 2, jnp.nan, 4]) * bu.meter + result = bu.math.percentile(quantity, 0.5 * bu.meter, unit_to_scale=bu.dametre) + if jnp.isnan(result) and jnp.isnan(jnp.percentile(array / bu.dametre.value, 50 / bu.dametre.value)): + self.assertTrue(True) + else: + self.assertTrue(result == jnp.percentile(array / bu.dametre.value, 0.5 / bu.dametre.value)) + class TestMathFuncsOnlyAcceptUnitlessBinary(unittest.TestCase):