Skip to content

Commit

Permalink
fix functions accepet unitless that receives binary inputs (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jul 7, 2024
1 parent 4cefe6b commit 52117bc
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',


# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',


# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
'corrcoef', 'correlate', 'cov', 'ldexp',
Expand Down Expand Up @@ -73,7 +71,7 @@ def _fun_accept_unitless_unary(
)
return func(x.to_value(unit_to_scale), *args, **kwargs)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)


Expand Down Expand Up @@ -780,7 +778,7 @@ def _fun_accept_unitless_return_keep_unit(
r = func(x.to_value(unit_to_scale), *args, **kwargs)
return jax.tree.map(lambda a: a * unit_to_scale, r)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)


Expand Down Expand Up @@ -1000,15 +998,11 @@ def _fun_accept_unitless_binary(
unit_to_scale: Optional[Unit] = None,
**kwargs
):
if isinstance(x, Quantity) and isinstance(y, Quantity):
if isinstance(x, Quantity):
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
assert y.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.value
y = y.value
return func(x, y, *args, **kwargs)
else:
fail_for_dimension_mismatch(
x,
Expand All @@ -1017,17 +1011,22 @@ def _fun_accept_unitless_binary(
value=x,
unit_to_scale=unit_to_scale
)
x = x.to_value(unit_to_scale)
if isinstance(y, Quantity):
if unit_to_scale is None:
assert y.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
y = y.value
else:
fail_for_dimension_mismatch(
y,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=y,
unit_to_scale=unit_to_scale
)
return func(x.to_value(unit_to_scale), y.to_value(unit_to_scale), *args, **kwargs)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" and "y" are not Quantities.'
return func(x, y, *args, **kwargs)
y = y.to_value(unit_to_scale)
return func(x, y, *args, **kwargs)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -1292,6 +1291,7 @@ def cov(
aweights=aweights, unit_to_scale=unit_to_scale
)


@set_module_as('brainunit.math')
def ldexp(
x: Union[Quantity, jax.typing.ArrayLike],
Expand Down Expand Up @@ -1499,4 +1499,3 @@ def right_shift(
Output array.
"""
return _fun_unitless_binary(jnp.right_shift, x, y)

0 comments on commit 52117bc

Please sign in to comment.