Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent a54759b commit 79d0be2
Showing 1 changed file with 113 additions and 163 deletions.
276 changes: 113 additions & 163 deletions brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,119 @@ def gcd(x1: Union[Quantity, jax.Array],
return _fun_keep_unit_binary(jnp.gcd, x1, x2)


@set_module_as('brainunit.math')
def add(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
) -> Union[Quantity, jax.Array]:
"""
Add arguments element-wise.
Parameters
----------
x, y : array_like, Quantity
The arrays to be added.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
add : ndarray or scalar
The sum of `x` and `y`, element-wise.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_keep_unit_binary(jnp.add, x, y)


@set_module_as('brainunit.math')
def subtract(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
) -> Union[Quantity, jax.Array]:
"""
subtract(x1, x2, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Subtract arguments, element-wise.
Parameters
----------
x, y : array_like
The arrays to be subtracted from each other.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
subtract : ndarray
The difference of `x` and `y`, element-wise.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_keep_unit_binary(jnp.subtract, x, y)


@set_module_as('brainunit.math')
def remainder(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
) -> Union[Quantity, jax.Array]:
"""
Returns the element-wise remainder of division.
Computes the remainder complementary to the `floor_divide` function. It is
equivalent to the Python modulus operator``x1 % x2`` and has the same sign
as the divisor `x2`. The MATLAB function equivalent to ``np.remainder``
is ``mod``.
Parameters
----------
x : array_like, Quantity
Dividend array.
y : array_like, Quantity
Divisor array.
If ``x1.shape != x2.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
out : ndarray, Quantity
The element-wise remainder of the quotient ``floor_divide(x1, x2)``.
This is a scalar if both `x1` and `x2` are scalars.
This is a Quantity if division of `x1` by `x2` is not dimensionless.
"""
return _fun_keep_unit_binary(jnp.remainder, x, y)


@set_module_as('brainunit.math')
def nextafter(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
) -> Union[Quantity, jax.Array]:
"""
nextafter(x, y, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Return the next floating-point value after x1 towards x2, element-wise.
Parameters
----------
x : array_like, Quantity
Values to find the next representable value of.
y : array_like, Quantity
The direction where to look for the next representable value of `x`.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
out : ndarray or scalar
The next representable values of `x` in the direction of `y`.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_keep_unit_binary(jnp.nextafter, x, y)


# math funcs keep unit (n-ary)
# ----------------------------
@set_module_as('brainunit.math')
Expand Down Expand Up @@ -2646,169 +2759,6 @@ def histogram(
return hist, Quantity(bin_edges, dim=dim)


def _fun_match_unit_binary(func, x, y, *args, **kwargs):
if isinstance(x, Quantity) and isinstance(y, Quantity):
fail_for_dimension_mismatch(x, y, func.__name__)
return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim)
elif isinstance(x, Quantity):
assert x.is_unitless, f'Expected unitless Quantity when y is not a Quantity, got {x}'
return func(x.value, y, *args, **kwargs)
elif isinstance(y, Quantity):
assert y.is_unitless, f'Expected unitless Quantity when x is not a Quantity, got {y}'
return func(x, y.value, *args, **kwargs)
else:
return func(x, y, *args, **kwargs)


@set_module_as('brainunit.math')
def add(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
*args,
**kwargs
) -> Union[Quantity, jax.Array]:
"""
Add arguments element-wise.
Parameters
----------
x, y : array_like, Quantity
The arrays to be added.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
where : array_like, optional
This condition is broadcast over the input. At locations where the
condition is True, the `out` array will be set to the ufunc result.
Elsewhere, the `out` array will retain its original value.
Note that if an uninitialized `out` array is created via the default
``out=None``, locations within it where the condition is False will
remain uninitialized.
**kwargs
For other keyword-only arguments, see the
:ref:`ufunc docs <ufuncs.kwargs>`.
Returns
-------
add : ndarray or scalar
The sum of `x` and `y`, element-wise.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_match_unit_binary(jnp.add, x, y, *args, **kwargs)


@set_module_as('brainunit.math')
def subtract(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
*args,
**kwargs
) -> Union[Quantity, jax.Array]:
"""
subtract(x1, x2, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Subtract arguments, element-wise.
Parameters
----------
x, y : array_like
The arrays to be subtracted from each other.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
where : array_like, optional
This condition is broadcast over the input. At locations where the
condition is True, the `out` array will be set to the ufunc result.
Elsewhere, the `out` array will retain its original value.
Note that if an uninitialized `out` array is created via the default
``out=None``, locations within it where the condition is False will
remain uninitialized.
**kwargs
For other keyword-only arguments, see the
:ref:`ufunc docs <ufuncs.kwargs>`.
Returns
-------
subtract : ndarray
The difference of `x` and `y`, element-wise.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_match_unit_binary(jnp.subtract, x, y, *args, **kwargs)


@set_module_as('brainunit.math')
def remainder(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
) -> Union[Quantity, jax.Array]:
"""
Returns the element-wise remainder of division.
Computes the remainder complementary to the `floor_divide` function. It is
equivalent to the Python modulus operator``x1 % x2`` and has the same sign
as the divisor `x2`. The MATLAB function equivalent to ``np.remainder``
is ``mod``.
Parameters
----------
x : array_like, Quantity
Dividend array.
y : array_like, Quantity
Divisor array.
If ``x1.shape != x2.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
Returns
-------
out : ndarray, Quantity
The element-wise remainder of the quotient ``floor_divide(x1, x2)``.
This is a scalar if both `x1` and `x2` are scalars.
This is a Quantity if division of `x1` by `x2` is not dimensionless.
"""
return _fun_match_unit_binary(jnp.remainder, x, y)


@set_module_as('brainunit.math')
def nextafter(
x: Union[Quantity, jax.Array],
y: Union[Quantity, jax.Array],
*args,
**kwargs
) -> Union[Quantity, jax.Array]:
"""
nextafter(x, y, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Return the next floating-point value after x1 towards x2, element-wise.
Parameters
----------
x : array_like, Quantity
Values to find the next representable value of.
y : array_like, Quantity
The direction where to look for the next representable value of `x`.
If ``x.shape != y.shape``, they must be broadcastable to a common
shape (which becomes the shape of the output).
where : array_like, optional
This condition is broadcast over the input. At locations where the
condition is True, the `out` array will be set to the ufunc result.
Elsewhere, the `out` array will retain its original value.
Note that if an uninitialized `out` array is created via the default
``out=None``, locations within it where the condition is False will
remain uninitialized.
**kwargs
For other keyword-only arguments, see the
:ref:`ufunc docs <ufuncs.kwargs>`.
Returns
-------
out : ndarray or scalar
The next representable values of `x` in the direction of `y`.
This is a scalar if both `x` and `y` are scalars.
"""
return _fun_match_unit_binary(jnp.nextafter, x, y, *args, **kwargs)


@set_module_as('brainunit.math')
def compress(
condition: jax.Array,
Expand Down

0 comments on commit 79d0be2

Please sign in to comment.