Skip to content

Commit

Permalink
Fix the logic of numpy_accept_unitless.py and Add their tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 26, 2024
1 parent 5576287 commit b59bbb3
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 48 deletions.
106 changes: 58 additions & 48 deletions brainunit/math/_numpy_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _func_accept_unitless_unary(
value=x,
unit_to_scale=unit_to_scale
)
return Quantity(func(x / unit_to_scale, *args, **kwargs), unit=unit_to_scale)
return func(x.to_value(unit_to_scale), *args, **kwargs)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)
Expand Down Expand Up @@ -889,30 +889,48 @@ def frexp(
# ----------------------------------------


def _fun_accept_unitless_binary(func, x, y, *args, **kwargs):
x_value = x.value if isinstance(x, Quantity) else x
y_value = y.value if isinstance(y, Quantity) else y
if isinstance(x, Quantity) or isinstance(y, Quantity):
fail_for_dimension_mismatch(
x,
y,
error_message="%s expects a dimensionless argument but got {value}" % func.__name__,
value=x,
)
fail_for_dimension_mismatch(
y,
error_message="%s expects a dimensionless argument but got {value}" % func.__name__,
value=y,
)
return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs)
def _fun_accept_unitless_binary(
func: Callable,
x: jax.typing.ArrayLike | Quantity,
y: jax.typing.ArrayLike | Quantity,
*args,
unit_to_scale: Optional[Unit] = None,
**kwargs):
if isinstance(x, Quantity) and isinstance(y, Quantity):
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
assert y.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.value
y = y.value
return func(x, y, *args, **kwargs)
else:
fail_for_dimension_mismatch(
x,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=x,
unit_to_scale=unit_to_scale
)
fail_for_dimension_mismatch(
y,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=y,
unit_to_scale=unit_to_scale
)
return func(x.to_value(unit_to_scale), y.to_value(unit_to_scale), *args, **kwargs)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" and "y" are not Quantities.'
return func(x, y, *args, **kwargs)


@set_module_as('brainunit.math')
def hypot(
x: Union[Array, Quantity],
y: Union[Array, Quantity],
unit_to_scale: Optional[Unit] = None,
) -> Array | Quantity:
"""
Given the “legs” of a right triangle, return its hypotenuse.
Expand All @@ -929,13 +947,14 @@ def hypot(
out : jax.Array
Output array.
"""
return _fun_accept_unitless_binary(jnp.hypot, x, y)
return _fun_accept_unitless_binary(jnp.hypot, x, y, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def arctan2(
x: Union[Array, Quantity],
y: Union[Array, Quantity],
unit_to_scale: Optional[Unit] = None,
) -> Array | Quantity:
"""
Element-wise arc tangent of `x1/x2` choosing the quadrant correctly.
Expand All @@ -952,13 +971,14 @@ def arctan2(
out : jax.Array
Output array.
"""
return _fun_accept_unitless_binary(jnp.arctan2, x, y)
return _fun_accept_unitless_binary(jnp.arctan2, x, y, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def logaddexp(
x: Union[Array, Quantity],
y: Union[Array, Quantity],
unit_to_scale: Optional[Unit] = None,
) -> Array | Quantity:
"""
Logarithm of the sum of exponentiations of the inputs.
Expand All @@ -975,11 +995,15 @@ def logaddexp(
out : jax.Array
Output array.
"""
return _fun_accept_unitless_binary(jnp.logaddexp, x, y)
return _fun_accept_unitless_binary(jnp.logaddexp, x, y, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array | Quantity:
def logaddexp2(
x: Union[Array, Quantity],
y: Union[Array, Quantity],
unit_to_scale: Optional[Unit] = None,
) -> Array | Quantity:
"""
Logarithm of the sum of exponentiations of the inputs in base-2.
Expand All @@ -995,17 +1019,17 @@ def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array |
out : jax.Array
Output array.
"""
return _fun_accept_unitless_binary(jnp.logaddexp2, x, y)
return _fun_accept_unitless_binary(jnp.logaddexp2, x, y, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def percentile(
a: Union[Array, Quantity],
q: Union[Array, Quantity],
axis: Optional[Union[int, Tuple[int]]] = None,
overwrite_input: Optional[bool] = None,
method: str = 'linear',
keepdims: Optional[bool] = False,
unit_to_scale: Optional[Unit] = None,
) -> Array:
"""
Compute the q-th percentile of the data along the specified axis.
Expand All @@ -1018,11 +1042,6 @@ def percentile(
Input array or Quantity.
q : array_like, Quantity
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
out : array_like, Quantity, optional
Alternative output array in which to place the result.
It must have the same shape and buffer length as the expected output but the type will be cast if necessary.
overwrite_input : bool, optional
If True, then allow the input array a to be modified by intermediate calculations, to save memory.
method : str, optional
This parameter specifies the method to use for estimating the
percentile. There are many different methods, some unique to NumPy.
Expand Down Expand Up @@ -1055,8 +1074,8 @@ def percentile(
Output array.
"""
return _fun_accept_unitless_binary(
jnp.percentile, a, q, axis=axis, overwrite_input=overwrite_input,
method=method, keepdims=keepdims
jnp.percentile, a, q, axis=axis,
method=method, keepdims=keepdims, unit_to_scale=unit_to_scale
)


Expand All @@ -1065,9 +1084,9 @@ def nanpercentile(
a: Union[Array, Quantity],
q: Union[Array, Quantity],
axis: Optional[Union[int, Tuple[int]]] = None,
overwrite_input: Optional[bool] = None,
method: str = 'linear',
keepdims: Optional[bool] = False,
unit_to_scale: Optional[Unit] = None,
) -> Array:
"""
Compute the q-th percentile of the data along the specified axis, while ignoring nan values.
Expand All @@ -1080,11 +1099,6 @@ def nanpercentile(
Input array or Quantity.
q : array_like, Quantity
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
out : array_like, Quantity, optional
Alternative output array in which to place the result.
It must have the same shape and buffer length as the expected output but the type will be cast if necessary.
overwrite_input : bool, optional
If True, then allow the input array a to be modified by intermediate calculations, to save memory.
method : str, optional
This parameter specifies the method to use for estimating the
percentile. There are many different methods, some unique to NumPy.
Expand Down Expand Up @@ -1118,8 +1132,8 @@ def nanpercentile(
"""
return _fun_accept_unitless_binary(
jnp.nanpercentile, a, q,
axis=axis, ooverwrite_input=overwrite_input,
method=method, keepdims=keepdims
axis=axis,
method=method, keepdims=keepdims, unit_to_scale=unit_to_scale
)


Expand All @@ -1128,9 +1142,9 @@ def quantile(
a: Union[Array, Quantity],
q: Union[Array, Quantity],
axis: Optional[Union[int, Tuple[int]]] = None,
overwrite_input: Optional[bool] = None,
method: str = 'linear',
keepdims: Optional[bool] = False,
unit_to_scale: Optional[Unit] = None,
) -> Array:
"""
Compute the q-th percentile of the data along the specified axis.
Expand All @@ -1143,8 +1157,6 @@ def quantile(
Input array or Quantity.
q : array_like, Quantity
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
overwrite_input : bool, optional
If True, then allow the input array a to be modified by intermediate calculations, to save memory.
method : str, optional
This parameter specifies the method to use for estimating the
percentile. There are many different methods, some unique to NumPy.
Expand Down Expand Up @@ -1178,8 +1190,8 @@ def quantile(
"""
return _fun_accept_unitless_binary(
jnp.quantile, a, q,
axis=axis, overwrite_input=overwrite_input,
method=method, keepdims=keepdims
axis=axis,
method=method, keepdims=keepdims, unit_to_scale=unit_to_scale
)


Expand All @@ -1188,9 +1200,9 @@ def nanquantile(
a: Union[Array, Quantity],
q: Union[Array, Quantity],
axis: Optional[Union[int, Tuple[int]]] = None,
overwrite_input: Optional[bool] = None,
method: str = 'linear',
keepdims: Optional[bool] = False,
unit_to_scale: Optional[Unit] = None,
) -> Array:
"""
Compute the q-th percentile of the data along the specified axis, while ignoring nan values.
Expand All @@ -1203,8 +1215,6 @@ def nanquantile(
Input array or Quantity.
q : array_like, Quantity
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
overwrite_input : bool, optional
If True, then allow the input array a to be modified by intermediate calculations, to save memory.
method : str, optional
This parameter specifies the method to use for estimating the
percentile. There are many different methods, some unique to NumPy.
Expand Down Expand Up @@ -1238,8 +1248,8 @@ def nanquantile(
"""
return _fun_accept_unitless_binary(
jnp.nanquantile, a, q,
axis=axis, overwrite_input=overwrite_input,
method=method, keepdims=keepdims
axis=axis,
method=method, keepdims=keepdims, unit_to_scale=unit_to_scale
)


Expand Down
Loading

0 comments on commit b59bbb3

Please sign in to comment.