diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 4510221..19095a9 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" + assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" if not jnp.allclose(q.value, values): raise AssertionError(f"Values do not match: {q.value} != {values}") elif isinstance(q, jnp.ndarray): @@ -144,10 +144,10 @@ def test_get_dimensions(): Test various ways of getting/comparing the dimensions of a Array. """ q = 500 * ms - assert get_unit(q) is get_or_create_dimension(q.unit._dims) - assert get_unit(q) is q.unit + assert get_unit(q) is get_or_create_dimension(q.dim._dims) + assert get_unit(q) is q.dim assert q.has_same_unit(3 * second) - dims = q.unit + dims = q.dim assert_equal(dims.get_dimension("time"), 1.0) assert_equal(dims.get_dimension("length"), 0) @@ -201,47 +201,54 @@ def test_unary_operations(): def test_operations(): - q1 = Quantity(5, dim=mV) - q2 = Quantity(10, dim=mV) - assert_quantity(q1 + q2, 15, mV) - assert_quantity(q1 - q2, -5, mV) - assert_quantity(q1 * q2, 50, mV * mV) + q1 = 5 * second + q2 = 10 * second + assert_quantity(q1 + q2, 15, second) + assert_quantity(q1 - q2, -5, second) + assert_quantity(q1 * q2, 50, second * second) assert_quantity(q2 / q1, 2, DIMENSIONLESS) assert_quantity(q2 // q1, 2, DIMENSIONLESS) - assert_quantity(q2 % q1, 0, mV) + assert_quantity(q2 % q1, 0, second) assert_quantity(divmod(q2, q1)[0], 2, DIMENSIONLESS) - assert_quantity(divmod(q2, q1)[1], 0, mV) - assert_quantity(q1 ** 2, 25, mV ** 2) - assert_quantity(q1 << 1, 10, mV) - assert_quantity(q1 >> 1, 2, mV) - assert_quantity(round(q1, 0), 5, mV) + assert_quantity(divmod(q2, q1)[1], 0, second) + assert_quantity(q1 ** 2, 25, second ** 2) + assert_quantity(round(q1, 0), 5, second) + # matmul - q1 = Quantity([1, 2], dim=mV) - q2 = Quantity([3, 4], dim=mV) - assert_quantity(q1 @ q2, 11, mV ** 2) + q1 = [1, 2] * second + q2 = [3, 4] * second + assert_quantity(q1 @ q2, 11, second ** 2) + q1 = Quantity([1, 2], unit=second) + q2 = Quantity([3, 4], unit=second) + assert_quantity(q1 @ q2, 11, second ** 2) + + # shift + q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS) + assert_quantity(q1 << 1, 0b11000, second) + assert_quantity(q1 >> 1, 0b110, second) def test_numpy_methods(): - q = Quantity([[1, 2], [3, 4]], dim=mV) + q = [[1, 2], [3, 4]] * second assert q.all() assert q.any() assert q.nonzero()[0].tolist() == [0, 0, 1, 1] assert q.argmax() == 3 assert q.argmin() == 0 assert q.argsort(axis=None).tolist() == [0, 1, 2, 3] - assert_quantity(q.var(), 1.25, mV ** 2) - assert_quantity(q.round(), [[1, 2], [3, 4]], mV) - assert_quantity(q.std(), 1.11803398875, mV) - assert_quantity(q.sum(), 10, mV) - assert_quantity(q.trace(), 5, mV) - assert_quantity(q.cumsum(), [1, 3, 6, 10], mV) - assert_quantity(q.cumprod(), [1, 2, 6, 24], mV ** 4) - assert_quantity(q.diagonal(), [1, 4], mV) - assert_quantity(q.max(), 4, mV) - assert_quantity(q.mean(), 2.5, mV) - assert_quantity(q.min(), 1, mV) - assert_quantity(q.ptp(), 3, mV) - assert_quantity(q.ravel(), [1, 2, 3, 4], mV) + assert_quantity(q.var(), 1.25, second ** 2) + assert_quantity(q.round(), [[1, 2], [3, 4]], second) + assert_quantity(q.std(), 1.11803398875, second) + assert_quantity(q.sum(), 10, second) + assert_quantity(q.trace(), 5, second) + assert_quantity(q.cumsum(), [1, 3, 6, 10], second) + assert_quantity(q.cumprod(), [1, 2, 6, 24], second ** 4) + assert_quantity(q.diagonal(), [1, 4], second) + assert_quantity(q.max(), 4, second) + assert_quantity(q.mean(), 2.5, second) + assert_quantity(q.min(), 1, second) + assert_quantity(q.ptp(), 3, second) + assert_quantity(q.ravel(), [1, 2, 3, 4], second) def test_shape_manipulation(): @@ -1596,7 +1603,7 @@ def test_constants(): import brainunit._unit_constants as constants # Check that the expected names exist and have the correct dimensions - assert constants.avogadro_constant.dim == (1 / mole).unit + assert constants.avogadro_constant.dim == (1 / mole).dim assert constants.boltzmann_constant.dim == (joule / kelvin).dim assert constants.electric_constant.dim == (farad / meter).dim assert constants.electron_mass.dim == kilogram.dim diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 68b77d5..e574603 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -13,12 +13,67 @@ # limitations under the License. # ============================================================================== -from ._compat_numpy import * -from ._compat_numpy import __all__ as _compat_numpy_all +# from ._compat_numpy import * +# from ._compat_numpy import __all__ as _compat_numpy_all from ._others import * from ._others import __all__ as _other_all +from ._compat_numpy_array_creation import * +from ._compat_numpy_array_creation import __all__ as _compat_array_creation_all +from ._compat_numpy_array_manipulation import * +from ._compat_numpy_array_manipulation import __all__ as _compat_array_manipulation_all +from ._compat_numpy_funcs_accept_unitless import * +from ._compat_numpy_funcs_accept_unitless import __all__ as _compat_funcs_accept_unitless_all +from ._compat_numpy_funcs_bit_operation import * +from ._compat_numpy_funcs_bit_operation import __all__ as _compat_funcs_bit_operation_all +from ._compat_numpy_funcs_change_unit import * +from ._compat_numpy_funcs_change_unit import __all__ as _compat_funcs_change_unit_all +from ._compat_numpy_funcs_indexing import * +from ._compat_numpy_funcs_indexing import __all__ as _compat_funcs_indexing_all +from ._compat_numpy_funcs_keep_unit import * +from ._compat_numpy_funcs_keep_unit import __all__ as _compat_funcs_keep_unit_all +from ._compat_numpy_funcs_logic import * +from ._compat_numpy_funcs_logic import __all__ as _compat_funcs_logic_all +from ._compat_numpy_funcs_match_unit import * +from ._compat_numpy_funcs_match_unit import __all__ as _compat_funcs_match_unit_all +from ._compat_numpy_funcs_remove_unit import * +from ._compat_numpy_funcs_remove_unit import __all__ as _compat_funcs_remove_unit_all +from ._compat_numpy_funcs_window import * +from ._compat_numpy_funcs_window import __all__ as _compat_funcs_window_all +from ._compat_numpy_get_attribute import * +from ._compat_numpy_get_attribute import __all__ as _compat_get_attribute_all +from ._compat_numpy_linear_algebra import * +from ._compat_numpy_linear_algebra import __all__ as _compat_linear_algebra_all +from ._compat_numpy_misc import * +from ._compat_numpy_misc import __all__ as _compat_misc_all -__all__ = _compat_numpy_all + _other_all - -del _compat_numpy_all, _other_all +__all__ = _compat_array_creation_all + \ + _compat_array_manipulation_all + \ + _compat_funcs_change_unit_all + \ + _compat_funcs_keep_unit_all + \ + _compat_funcs_accept_unitless_all + \ + _compat_funcs_match_unit_all + \ + _compat_funcs_remove_unit_all + \ + _compat_get_attribute_all + \ + _compat_funcs_bit_operation_all + \ + _compat_funcs_logic_all + \ + _compat_funcs_indexing_all + \ + _compat_funcs_window_all + \ + _compat_linear_algebra_all + \ + _compat_misc_all + _other_all + \ + _other_all +del _compat_array_creation_all, \ + _compat_array_manipulation_all, \ + _compat_funcs_change_unit_all, \ + _compat_funcs_keep_unit_all, \ + _compat_funcs_accept_unitless_all, \ + _compat_funcs_match_unit_all, \ + _compat_funcs_remove_unit_all, \ + _compat_get_attribute_all, \ + _compat_funcs_bit_operation_all, \ + _compat_funcs_logic_all, \ + _compat_funcs_indexing_all, \ + _compat_funcs_window_all, \ + _compat_linear_algebra_all, \ + _compat_misc_all, \ + _other_all diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py deleted file mode 100644 index b150455..0000000 --- a/brainunit/math/_compat_numpy.py +++ /dev/null @@ -1,1455 +0,0 @@ -# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from collections.abc import Sequence -from functools import wraps -from typing import (Callable, Union, Optional) - -import brainstate as bst -import jax -import jax.numpy as jnp -import numpy as np -import opt_einsum -from brainstate._utils import set_module_as -from jax._src.numpy.lax_numpy import _einsum - -from ._utils import _compatible_with_quantity -from .._base import (DIMENSIONLESS, - Quantity, - Unit, - fail_for_dimension_mismatch, - is_unitless, - get_unit, ) -from .._base import _return_check_unitless - -__all__ = [ - # array creation - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', - 'array_split', 'meshgrid', 'vander', - - # getting attribute funcs - 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', - 'isnan', 'shape', 'size', - - # math funcs keep unit (unary) - 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', - 'abs', 'round', 'around', 'round_', 'rint', - 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', - 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', - 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', - 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', - - # math funcs keep unit (binary) - 'fmod', 'mod', 'copysign', 'heaviside', - 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', - - # math funcs keep unit (n-ary) - 'interp', 'clip', - - # math funcs match unit (binary) - 'add', 'subtract', 'nextafter', - - # math funcs change unit (unary) - 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', - 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', - - # math funcs change unit (binary) - 'multiply', 'divide', 'power', 'cross', 'ldexp', - 'true_divide', 'floor_divide', 'float_power', - 'divmod', 'remainder', 'convolve', - - # math funcs only accept unitless (unary) - 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', - 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', - 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - - # math funcs only accept unitless (binary) - 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', - - # math funcs remove unit (unary) - 'signbit', 'sign', 'histogram', 'bincount', - - # math funcs remove unit (binary) - 'corrcoef', 'correlate', 'cov', 'digitize', - - # array manipulation - 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', - 'diagflat', 'diagonal', 'choose', 'ravel', - - # Elementwise bit operations (unary) - 'bitwise_not', 'invert', 'left_shift', 'right_shift', - - # Elementwise bit operations (binary) - 'bitwise_and', 'bitwise_or', 'bitwise_xor', - - # logic funcs (unary) - 'all', 'any', 'logical_not', - - # logic funcs (binary) - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', - 'logical_or', 'logical_xor', "alltrue", 'sometrue', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', - - # more - 'broadcast_arrays', 'broadcast_shapes', - 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', - 'rot90', 'tensordot', - -] - - -# array creation -# -------------- - -def wrap_array_creation_function(func): - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -# array creation -# -------------- - -full = wrap_array_creation_function(jnp.full) -eye = wrap_array_creation_function(jnp.eye) -identity = wrap_array_creation_function(jnp.identity) -tri = wrap_array_creation_function(jnp.tri) -empty = wrap_array_creation_function(jnp.empty) -ones = wrap_array_creation_function(jnp.ones) -zeros = wrap_array_creation_function(jnp.zeros) - - -@set_module_as('brainunit.math') -def full_like(a, fill_value, dtype=None, shape=None): - if isinstance(a, Quantity) and isinstance(fill_value, Quantity): - fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.') - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity): - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like') - - -@set_module_as('brainunit.math') -def diag(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.diag(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.diag(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for diag') - - -@set_module_as('brainunit.math') -def tril(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.tril(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.tril(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for tril') - - -@set_module_as('brainunit.math') -def triu(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.triu(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.triu(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for triu') - - -@set_module_as('brainunit.math') -def empty_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.empty_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for empty_like') - - -@set_module_as('brainunit.math') -def ones_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.ones_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for ones_like') - - -@set_module_as('brainunit.math') -def zeros_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.zeros_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for zeros_like') - - -@set_module_as('brainunit.math') -def asarray( - a, - dtype: Optional[bst.typing.DTypeLike] = None, - order: Optional[str] = None, - unit: Optional[Unit] = None, -): - from builtins import all as origin_all - from builtins import any as origin_any - if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.asarray(a, dtype=dtype, order=order) - # list[Quantity] - elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if origin_any(x.dim != a[0].dim for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].dim - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) - else: - return jnp.asarray(a, dtype=dtype, order=order) - - -array = asarray - - -@set_module_as('brainunit.math') -def arange(*args, **kwargs): - # arange has a bit of a complicated argument structure unfortunately - # we leave the actual checking of the number of arguments to numpy, though - - # default values - start = kwargs.pop("start", 0) - step = kwargs.pop("step", 1) - stop = kwargs.pop("stop", None) - if len(args) == 1: - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - stop = args[0] - elif len(args) == 2: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - start, stop = args - elif len(args) == 3: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - if step != 1: - raise TypeError("Duplicate definition of 'step'") - start, stop, step = args - elif len(args) > 3: - raise TypeError("Need between 1 and 3 non-keyword arguments") - - if stop is None: - raise TypeError("Missing stop argument.") - if stop is not None and not is_unitless(stop): - start = Quantity(start, dim=stop.unit) - - fail_for_dimension_mismatch( - start, - stop, - error_message=( - "Start value {start} and stop value {stop} have to have the same units." - ), - start=start, - stop=stop, - ) - fail_for_dimension_mismatch( - stop, - step, - error_message=( - "Stop value {stop} and step value {step} have to have the same units." - ), - stop=stop, - step=step, - ) - unit = getattr(stop, "unit", DIMENSIONLESS) - # start is a position-only argument in numpy 2.0 - # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only - # TODO: check whether this is still the case in the final release - if start == 0: - return Quantity( - jnp.arange( - start=start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - dim=unit, - ) - else: - return Quantity( - jnp.arange( - start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - dim=unit, - ) - - -@set_module_as('brainunit.math') -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) - return Quantity(result, dim=unit) - - -@set_module_as('brainunit.math') -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) - return Quantity(result, dim=unit) - - -@set_module_as('brainunit.math') -def fill_diagonal(a, val, wrap=False, inplace=True): - if isinstance(a, Quantity) and isinstance(val, Quantity): - fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - elif is_unitless(a) or is_unitless(val): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') - - -@set_module_as('brainunit.math') -def array_split(ary, indices_or_sections, axis=0): - if isinstance(ary, Quantity): - return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), dim=ary.unit) - elif isinstance(ary, (jax.Array, np.ndarray)): - return jnp.array_split(ary, indices_or_sections, axis) - else: - raise ValueError(f'Unsupported type: {type(ary)} for array_split') - - -@set_module_as('brainunit.math') -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): - from builtins import all as origin_all - if origin_all(isinstance(x, Quantity) for x in xi): - fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) - elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): - return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) - else: - raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') - - -@set_module_as('brainunit.math') -def vander(x, N=None, increasing=False): - if isinstance(x, Quantity): - return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return jnp.vander(x, N=N, increasing=increasing) - else: - raise ValueError(f'Unsupported type: {type(x)} for vander') - - -# getting attribute funcs -# ----------------------- - -@set_module_as('brainunit.math') -def ndim(a): - if isinstance(a, Quantity): - return a.ndim - else: - return jnp.ndim(a) - - -@set_module_as('brainunit.math') -def isreal(a): - if isinstance(a, Quantity): - return a.isreal - else: - return jnp.isreal(a) - - -@set_module_as('brainunit.math') -def isscalar(a): - if isinstance(a, Quantity): - return a.isscalar - else: - return jnp.isscalar(a) - - -@set_module_as('brainunit.math') -def isfinite(a): - if isinstance(a, Quantity): - return a.isfinite - else: - return jnp.isfinite(a) - - -@set_module_as('brainunit.math') -def isinf(a): - if isinstance(a, Quantity): - return a.isinf - else: - return jnp.isinf(a) - - -@set_module_as('brainunit.math') -def isnan(a): - if isinstance(a, Quantity): - return a.isnan - else: - return jnp.isnan(a) - - -@set_module_as('brainunit.math') -def shape(a): - """ - Return the shape of an array. - - Parameters - ---------- - a : array_like - Input array. - - Returns - ------- - shape : tuple of ints - The elements of the shape tuple give the lengths of the - corresponding array dimensions. - - See Also - -------- - len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with - ``N>=1``. - ndarray.shape : Equivalent array method. - - Examples - -------- - >>> brainunit.math.shape(brainunit.math.eye(3)) - (3, 3) - >>> brainunit.math.shape([[1, 3]]) - (1, 2) - >>> brainunit.math.shape([0]) - (1,) - >>> brainunit.math.shape(0) - () - - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - return a.shape - else: - return np.shape(a) - - -@set_module_as('brainunit.math') -def size(a, axis=None): - """ - Return the number of elements along a given axis. - - Parameters - ---------- - a : array_like - Input data. - axis : int, optional - Axis along which the elements are counted. By default, give - the total number of elements. - - Returns - ------- - element_count : int - Number of elements along the specified axis. - - See Also - -------- - shape : dimensions of array - Array.shape : dimensions of array - Array.size : number of elements in array - - Examples - -------- - >>> a = Quantity([[1,2,3], [4,5,6]]) - >>> brainunit.math.size(a) - 6 - >>> brainunit.math.size(a, 1) - 3 - >>> brainunit.math.size(a, 0) - 2 - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - if axis is None: - return a.size - else: - return a.shape[axis] - else: - return np.size(a, axis=axis) - - -# math funcs keep unit (unary) -# ---------------------------- - -def wrap_math_funcs_keep_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -real = wrap_math_funcs_keep_unit_unary(jnp.real) -imag = wrap_math_funcs_keep_unit_unary(jnp.imag) -conj = wrap_math_funcs_keep_unit_unary(jnp.conj) -conjugate = wrap_math_funcs_keep_unit_unary(jnp.conjugate) -negative = wrap_math_funcs_keep_unit_unary(jnp.negative) -positive = wrap_math_funcs_keep_unit_unary(jnp.positive) -abs = wrap_math_funcs_keep_unit_unary(jnp.abs) -round_ = wrap_math_funcs_keep_unit_unary(jnp.round) -around = wrap_math_funcs_keep_unit_unary(jnp.around) -round = wrap_math_funcs_keep_unit_unary(jnp.round) -rint = wrap_math_funcs_keep_unit_unary(jnp.rint) -floor = wrap_math_funcs_keep_unit_unary(jnp.floor) -ceil = wrap_math_funcs_keep_unit_unary(jnp.ceil) -trunc = wrap_math_funcs_keep_unit_unary(jnp.trunc) -fix = wrap_math_funcs_keep_unit_unary(jnp.fix) -sum = wrap_math_funcs_keep_unit_unary(jnp.sum) -nancumsum = wrap_math_funcs_keep_unit_unary(jnp.nancumsum) -nansum = wrap_math_funcs_keep_unit_unary(jnp.nansum) -cumsum = wrap_math_funcs_keep_unit_unary(jnp.cumsum) -ediff1d = wrap_math_funcs_keep_unit_unary(jnp.ediff1d) -absolute = wrap_math_funcs_keep_unit_unary(jnp.absolute) -fabs = wrap_math_funcs_keep_unit_unary(jnp.fabs) -median = wrap_math_funcs_keep_unit_unary(jnp.median) -nanmin = wrap_math_funcs_keep_unit_unary(jnp.nanmin) -nanmax = wrap_math_funcs_keep_unit_unary(jnp.nanmax) -ptp = wrap_math_funcs_keep_unit_unary(jnp.ptp) -average = wrap_math_funcs_keep_unit_unary(jnp.average) -mean = wrap_math_funcs_keep_unit_unary(jnp.mean) -std = wrap_math_funcs_keep_unit_unary(jnp.std) -nanmedian = wrap_math_funcs_keep_unit_unary(jnp.nanmedian) -nanmean = wrap_math_funcs_keep_unit_unary(jnp.nanmean) -nanstd = wrap_math_funcs_keep_unit_unary(jnp.nanstd) -diff = wrap_math_funcs_keep_unit_unary(jnp.diff) -modf = wrap_math_funcs_keep_unit_unary(jnp.modf) - - -# math funcs keep unit (binary) -# ----------------------------- - -def wrap_math_funcs_keep_unit_binary(func): - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.unit) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) -mod = wrap_math_funcs_keep_unit_binary(jnp.mod) -copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) -heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) -maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) -minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) -fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) -fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) -lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) -gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) - - -# math funcs keep unit (n-ary) -# ---------------------------- -@set_module_as('brainunit.math') -def interp(x, xp, fp, left=None, right=None, period=None): - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value - else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -@set_module_as('brainunit.math') -def clip(a, a_min, a_max): - unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -# math funcs match unit (binary) -# ------------------------------ - -def wrap_math_funcs_match_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), dim=x.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), dim=y.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -add = wrap_math_funcs_match_unit_binary(jnp.add) -subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) -nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) - - -# math funcs change unit (unary) -# ------------------------------ - -def wrap_math_funcs_change_unit_unary(func, change_unit_func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.unit))) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) - - -@set_module_as('brainunit.math') -def prod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - - -@set_module_as('brainunit.math') -def nanprod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - - -product = prod - - -@set_module_as('brainunit.math') -def cumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) - - -@set_module_as('brainunit.math') -def nancumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) - - -cumproduct = cumprod - -var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) -nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) -frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) -sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) -cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) -square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) - - -# math funcs change unit (binary) -# ------------------------------- - -def wrap_math_funcs_change_unit_binary(func, change_unit_func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.unit, y.unit)) - ) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.unit, DIMENSIONLESS))) - elif isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.unit))) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) -divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value, *args, **kwargs), dim=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y, *args, **kwargs), dim=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value, *args, **kwargs), dim=x ** y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') - - -cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) -ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) -true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def floor_divide(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value, *args, **kwargs), dim=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.floor_divide(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y, *args, **kwargs), dim=x.unit / y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value, *args, **kwargs), dim=x / y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') - - -@set_module_as('brainunit.math') -def float_power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y.value, *args, **kwargs), dim=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.float_power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y, *args, **kwargs), dim=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x, y.value, *args, **kwargs), dim=x ** y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') - - -divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def remainder(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value, *args, **kwargs), dim=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.remainder(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y, *args, **kwargs), dim=x.unit % y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value, *args, **kwargs), dim=x % y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') - - -convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) - - -# math funcs only accept unitless (unary) -# --------------------------------------- - -def wrap_math_funcs_only_accept_unitless_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - return func(jnp.array(x.value), *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -exp = wrap_math_funcs_only_accept_unitless_unary(jnp.exp) -exp2 = wrap_math_funcs_only_accept_unitless_unary(jnp.exp2) -expm1 = wrap_math_funcs_only_accept_unitless_unary(jnp.expm1) -log = wrap_math_funcs_only_accept_unitless_unary(jnp.log) -log10 = wrap_math_funcs_only_accept_unitless_unary(jnp.log10) -log1p = wrap_math_funcs_only_accept_unitless_unary(jnp.log1p) -log2 = wrap_math_funcs_only_accept_unitless_unary(jnp.log2) -arccos = wrap_math_funcs_only_accept_unitless_unary(jnp.arccos) -arccosh = wrap_math_funcs_only_accept_unitless_unary(jnp.arccosh) -arcsin = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsin) -arcsinh = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsinh) -arctan = wrap_math_funcs_only_accept_unitless_unary(jnp.arctan) -arctanh = wrap_math_funcs_only_accept_unitless_unary(jnp.arctanh) -cos = wrap_math_funcs_only_accept_unitless_unary(jnp.cos) -cosh = wrap_math_funcs_only_accept_unitless_unary(jnp.cosh) -sin = wrap_math_funcs_only_accept_unitless_unary(jnp.sin) -sinc = wrap_math_funcs_only_accept_unitless_unary(jnp.sinc) -sinh = wrap_math_funcs_only_accept_unitless_unary(jnp.sinh) -tan = wrap_math_funcs_only_accept_unitless_unary(jnp.tan) -tanh = wrap_math_funcs_only_accept_unitless_unary(jnp.tanh) -deg2rad = wrap_math_funcs_only_accept_unitless_unary(jnp.deg2rad) -rad2deg = wrap_math_funcs_only_accept_unitless_unary(jnp.rad2deg) -degrees = wrap_math_funcs_only_accept_unitless_unary(jnp.degrees) -radians = wrap_math_funcs_only_accept_unitless_unary(jnp.radians) -angle = wrap_math_funcs_only_accept_unitless_unary(jnp.angle) -percentile = wrap_math_funcs_only_accept_unitless_unary(jnp.percentile) -nanpercentile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanpercentile) -quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) -nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) - - -# math funcs only accept unitless (binary) -# ---------------------------------------- - -def wrap_math_funcs_only_accept_unitless_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - fail_for_dimension_mismatch( - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=y, - ) - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -hypot = wrap_math_funcs_only_accept_unitless_binary(jnp.hypot) -arctan2 = wrap_math_funcs_only_accept_unitless_binary(jnp.arctan2) -logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) -logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) - - -# math funcs remove unit (unary) -# ------------------------------ -def wrap_math_funcs_remove_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return func(x.value, *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -signbit = wrap_math_funcs_remove_unit_unary(jnp.signbit) -sign = wrap_math_funcs_remove_unit_unary(jnp.sign) -histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) -bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) - - -# math funcs remove unit (binary) -# ------------------------------- -def wrap_math_funcs_remove_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -corrcoef = wrap_math_funcs_remove_unit_binary(jnp.corrcoef) -correlate = wrap_math_funcs_remove_unit_binary(jnp.correlate) -cov = wrap_math_funcs_remove_unit_binary(jnp.cov) -digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) - -# array manipulation -# ------------------ - -reshape = _compatible_with_quantity(jnp.reshape) -moveaxis = _compatible_with_quantity(jnp.moveaxis) -transpose = _compatible_with_quantity(jnp.transpose) -swapaxes = _compatible_with_quantity(jnp.swapaxes) -concatenate = _compatible_with_quantity(jnp.concatenate) -stack = _compatible_with_quantity(jnp.stack) -vstack = _compatible_with_quantity(jnp.vstack) -row_stack = vstack -hstack = _compatible_with_quantity(jnp.hstack) -dstack = _compatible_with_quantity(jnp.dstack) -column_stack = _compatible_with_quantity(jnp.column_stack) -split = _compatible_with_quantity(jnp.split) -dsplit = _compatible_with_quantity(jnp.dsplit) -hsplit = _compatible_with_quantity(jnp.hsplit) -vsplit = _compatible_with_quantity(jnp.vsplit) -tile = _compatible_with_quantity(jnp.tile) -repeat = _compatible_with_quantity(jnp.repeat) -unique = _compatible_with_quantity(jnp.unique) -append = _compatible_with_quantity(jnp.append) -flip = _compatible_with_quantity(jnp.flip) -fliplr = _compatible_with_quantity(jnp.fliplr) -flipud = _compatible_with_quantity(jnp.flipud) -roll = _compatible_with_quantity(jnp.roll) -atleast_1d = _compatible_with_quantity(jnp.atleast_1d) -atleast_2d = _compatible_with_quantity(jnp.atleast_2d) -atleast_3d = _compatible_with_quantity(jnp.atleast_3d) -expand_dims = _compatible_with_quantity(jnp.expand_dims) -squeeze = _compatible_with_quantity(jnp.squeeze) -sort = _compatible_with_quantity(jnp.sort) - -max = _compatible_with_quantity(jnp.max) -min = _compatible_with_quantity(jnp.min) - -amax = max -amin = min - -choose = _compatible_with_quantity(jnp.choose) -block = _compatible_with_quantity(jnp.block) -compress = _compatible_with_quantity(jnp.compress) -diagflat = _compatible_with_quantity(jnp.diagflat) - -# return jax.numpy.Array, not Quantity -argsort = _compatible_with_quantity(jnp.argsort, return_quantity=False) -argmax = _compatible_with_quantity(jnp.argmax, return_quantity=False) -argmin = _compatible_with_quantity(jnp.argmin, return_quantity=False) -argwhere = _compatible_with_quantity(jnp.argwhere, return_quantity=False) -nonzero = _compatible_with_quantity(jnp.nonzero, return_quantity=False) -flatnonzero = _compatible_with_quantity(jnp.flatnonzero, return_quantity=False) -searchsorted = _compatible_with_quantity(jnp.searchsorted, return_quantity=False) -extract = _compatible_with_quantity(jnp.extract, return_quantity=False) -count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) - - -def wrap_function_to_method(func): - @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.unit) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -diagonal = wrap_function_to_method(jnp.diagonal) -ravel = wrap_function_to_method(jnp.ravel) - - -# Elementwise bit operations (unary) -# ---------------------------------- - -def wrap_elementwise_bit_operation_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected integers, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) -invert = wrap_elementwise_bit_operation_unary(jnp.invert) -left_shift = wrap_elementwise_bit_operation_unary(jnp.left_shift) -right_shift = wrap_elementwise_bit_operation_unary(jnp.right_shift) - - -# Elementwise bit operations (binary) -# ----------------------------------- - -def wrap_elementwise_bit_operation_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) or isinstance(y, Quantity): - raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) -bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) -bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) - - -# logic funcs (unary) -# ------------------- - -def wrap_logic_func_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected booleans, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -all = wrap_logic_func_unary(jnp.all) -any = wrap_logic_func_unary(jnp.any) -alltrue = all -sometrue = any -logical_not = wrap_logic_func_unary(jnp.logical_not) - - -# logic funcs (binary) -# -------------------- - -def wrap_logic_func_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return func(x.value, y.value, *args, **kwargs) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -equal = wrap_logic_func_binary(jnp.equal) -not_equal = wrap_logic_func_binary(jnp.not_equal) -greater = wrap_logic_func_binary(jnp.greater) -greater_equal = wrap_logic_func_binary(jnp.greater_equal) -less = wrap_logic_func_binary(jnp.less) -less_equal = wrap_logic_func_binary(jnp.less_equal) -array_equal = wrap_logic_func_binary(jnp.array_equal) -isclose = wrap_logic_func_binary(jnp.isclose) -allclose = wrap_logic_func_binary(jnp.allclose) -logical_and = wrap_logic_func_binary(jnp.logical_and) - -logical_or = wrap_logic_func_binary(jnp.logical_or) -logical_xor = wrap_logic_func_binary(jnp.logical_xor) - - -# indexing funcs -# -------------- -@set_module_as('brainunit.math') -def where(condition, *args, **kwds): # pylint: disable=C0111 - condition = jnp.asarray(condition) - if len(args) == 0: - # nothing to do - return jnp.where(condition, *args, **kwds) - elif len(args) == 2: - # check that x and y have the same dimensions - fail_for_dimension_mismatch( - args[0], args[1], "x and y need to have the same dimensions" - ) - new_args = [] - for arg in args: - if isinstance(arg, Quantity): - new_args.append(arg.value) - if is_unitless(args[0]): - if len(new_args) == 2: - return jnp.where(condition, *new_args, **kwds) - else: - return jnp.where(condition, *args, **kwds) - else: - # as both arguments have the same unit, just use the first one's - dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] - return Quantity.with_units( - jnp.where(condition, *dimensionless_args), args[0].dim - ) - else: - # illegal number of arguments - if len(args) == 1: - raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") - elif len(args) > 2: - raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) - - -tril_indices = jnp.tril_indices - - -@set_module_as('brainunit.math') -def tril_indices_from(arr, k=0): - if isinstance(arr, Quantity): - return jnp.tril_indices_from(arr.value, k=k) - else: - return jnp.tril_indices_from(arr, k=k) - - -triu_indices = jnp.triu_indices - - -@set_module_as('brainunit.math') -def triu_indices_from(arr, k=0): - if isinstance(arr, Quantity): - return jnp.triu_indices_from(arr.value, k=k) - else: - return jnp.triu_indices_from(arr, k=k) - - -@set_module_as('brainunit.math') -def take(a, indices, axis=None, mode=None): - if isinstance(a, Quantity): - return a.take(indices, axis=axis, mode=mode) - else: - return jnp.take(a, indices, axis=axis, mode=mode) - - -@set_module_as('brainunit.math') -def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quantity, jax.Array, np.ndarray], default=0): - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): - raise ValueError("All choices must have the same unit") - else: - return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), - dim=choicelist[0].dim) - elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): - return jnp.select(condlist, choicelist, default=default) - else: - raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") - - -# window funcs -# ------------ - -def wrap_window_funcs(func): - def f(*args, **kwargs): - return Quantity(func(*args, **kwargs)) - - f.__module__ = 'brainunit.math' - return f - - -bartlett = wrap_window_funcs(jnp.bartlett) -blackman = wrap_window_funcs(jnp.blackman) -hamming = wrap_window_funcs(jnp.hamming) -hanning = wrap_window_funcs(jnp.hanning) -kaiser = wrap_window_funcs(jnp.kaiser) - -# constants -# --------- -e = jnp.e -pi = jnp.pi -inf = jnp.inf - -# linear algebra -# -------------- -dot = wrap_math_funcs_change_unit_binary(jnp.dot, lambda x, y: x * y) -vdot = wrap_math_funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y) -inner = wrap_math_funcs_change_unit_binary(jnp.inner, lambda x, y: x * y) -outer = wrap_math_funcs_change_unit_binary(jnp.outer, lambda x, y: x * y) -kron = wrap_math_funcs_change_unit_binary(jnp.kron, lambda x, y: x * y) -matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) -trace = wrap_math_funcs_keep_unit_unary(jnp.trace) - -# data types -# ---------- -dtype = jnp.dtype - - -@set_module_as('brainunit.math') -def finfo(a): - if isinstance(a, Quantity): - return jnp.finfo(a.value) - else: - return jnp.finfo(a) - - -@set_module_as('brainunit.math') -def iinfo(a): - if isinstance(a, Quantity): - return jnp.iinfo(a.value) - else: - return jnp.iinfo(a) - - -# more -# ---- -@set_module_as('brainunit.math') -def broadcast_arrays(*args): - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.dim != args[0].dim for arg in args): - raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) - elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): - return jnp.broadcast_arrays(*args) - else: - raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") - - -broadcast_shapes = jnp.broadcast_shapes - - -@set_module_as('brainunit.math') -def einsum( - subscripts, /, - *operands, - out: None = None, - optimize: Union[str, bool] = "optimal", - precision: jax.lax.PrecisionLike = None, - preferred_element_type: Union[jax.typing.DTypeLike, None] = None, - _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, -) -> Union[jax.Array, Quantity]: - operands = (subscripts, *operands) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") - spec = operands[0] if isinstance(operands[0], str) else None - optimize = 'optimal' if optimize is True else optimize - - # Allow handling of shape polymorphism - non_constant_dim_types = { - type(d) for op in operands if not isinstance(op, str) - for d in np.shape(op) if not jax.core.is_constant_dim(d) - } - if not non_constant_dim_types: - contract_path = opt_einsum.contract_path - else: - from jax._src.numpy.lax_numpy import _default_poly_einsum_handler - contract_path = _default_poly_einsum_handler - - operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) - - unit = None - for i in range(len(contractions) - 1): - if contractions[i][4] == 'False': - - fail_for_dimension_mismatch( - Quantity([], dim=unit), operands[i + 1], 'einsum' - ) - elif contractions[i][4] == 'DOT' or \ - contractions[i][4] == 'TDOT' or \ - contractions[i][4] == 'GEMM' or \ - contractions[i][4] == 'OUTER/EINSUM': - if i == 0: - if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): - unit = operands[i].dim * operands[i + 1].dim - elif isinstance(operands[i], Quantity): - unit = operands[i].dim - elif isinstance(operands[i + 1], Quantity): - unit = operands[i + 1].dim - else: - if isinstance(operands[i + 1], Quantity): - unit = unit * operands[i + 1].dim - - contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - - einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) - if spec is not None: - einsum = jax.named_call(einsum, name=spec) - operands = [op.value if isinstance(op, Quantity) else op for op in operands] - r = einsum(operands, contractions, precision, # type: ignore[operator] - preferred_element_type, _dot_general) - if unit is not None: - return Quantity(r, dim=unit) - else: - return r - - -@set_module_as('brainunit.math') -def gradient( - f: Union[jax.Array, np.ndarray, Quantity], - *varargs: Union[jax.Array, np.ndarray, Quantity], - axis: Union[int, Sequence[int], None] = None, - edge_order: Union[int, None] = None, -) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: - if edge_order is not None: - raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") - - if len(varargs) == 0: - if isinstance(f, Quantity) and not is_unitless(f): - return Quantity(jnp.gradient(f.value, axis=axis), dim=f.unit) - else: - return jnp.gradient(f) - elif len(varargs) == 1: - unit = get_unit(f) / get_unit(varargs[0]) - if unit is None or unit == DIMENSIONLESS: - return jnp.gradient(f, varargs[0], axis=axis) - else: - return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] - else: - unit_list = [get_unit(f) / get_unit(v) for v in varargs] - f = f.value if isinstance(f, Quantity) else f - varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] - result_list = jnp.gradient(f, *varargs, axis=axis) - return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] - - -@set_module_as('brainunit.math') -def intersect1d( - ar1: Union[jax.Array, np.ndarray], - ar2: Union[jax.Array, np.ndarray], - assume_unique: bool = False, - return_indices: bool = False -) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: - fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') - unit = None - if isinstance(ar1, Quantity): - unit = ar1.unit - ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 - ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 - result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - if return_indices: - if unit is not None: - return (Quantity(result[0], dim=unit), result[1], result[2]) - else: - return result - else: - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -nan_to_num = wrap_math_funcs_keep_unit_unary(jnp.nan_to_num) -nanargmax = _compatible_with_quantity(jnp.nanargmax, return_quantity=False) -nanargmin = _compatible_with_quantity(jnp.nanargmin, return_quantity=False) - -rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) -tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py new file mode 100644 index 0000000..156a553 --- /dev/null +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -0,0 +1,720 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional, Any) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as +from jax import Array + +from .._base import (DIMENSIONLESS, + Quantity, + Unit, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + # array creation + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + 'array_split', 'meshgrid', 'vander', +] + + +def wrap_array_creation_function(func: Callable) -> Callable: + @wraps(func) + def f(*args, unit: Unit = None, **kwargs): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return func(*args, **kwargs) * unit + else: + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_array_creation_function +def full(shape: Sequence[int], + fill_value: Any, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.full(shape, fill_value, dtype=dtype) + + +@wrap_array_creation_function +def eye(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.eye(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def identity(n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.identity(n, dtype=dtype) + + +@wrap_array_creation_function +def tri(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.tri(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def empty(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.empty(shape, dtype=dtype) + + +@wrap_array_creation_function +def ones(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.ones(shape, dtype=dtype) + + +@wrap_array_creation_function +def zeros(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.zeros(shape, dtype=dtype) + + +full.__doc__ = ''' + Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `shape` filled with `fill_value`. + + Args: + shape: sequence of integers, describing the shape of the output array. + fill_value: the value to fill the new array with. + dtype: the type of the output array, or `None`. If not `None`, `fill_value` + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + +eye.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +identity.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +tri.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. + else return a triangular matrix of `shape`. + + Args: + n: the number of rows in the output array. + m: the number of columns with default being `n`. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# empty +empty.__doc__ = """ + Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `shape` with uninitialized values. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be of type `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# ones +ones.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. + else return an array of `shape` filled with 1. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# zeros +zeros.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. + else return an array of `shape` filled with 0. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + + +@set_module_as('brainunit.math') +def full_like(a: Union[Quantity, bst.typing.ArrayLike], + fill_value: Union[bst.typing.ArrayLike], + unit: Unit = None, + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `a` filled with `fill_value`. + + Args: + a: array_like, Quantity, shape, or dtype + fill_value: scalar or array_like + unit: Unit, optional + dtype: data-type, optional + shape: sequence of ints, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def diag(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Extract a diagonal or construct a diagonal array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.diag(a.value, k=k) * unit + else: + return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) + + +@set_module_as('brainunit.math') +def tril(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Lower triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.tril(a.value, k=k) * unit + else: + return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) + + +@set_module_as('brainunit.math') +def triu(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Upper triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.triu(a.value, k=k) * unit + else: + return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) + + +@set_module_as('brainunit.math') +def empty_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `a` with uninitialized values. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def ones_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. + else return an array of `a` filled with 1. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. + else return an array of `a` filled with 0. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def asarray( + a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], + dtype: Optional[bst.typing.DTypeLike] = None, + order: Optional[str] = None, + unit: Optional[Unit] = None, +) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if isinstance(a, Quantity): + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): + # check all elements have the same unit + if origin_any(x.dim != a[0].dim for x in a): + raise ValueError('Units do not match for asarray operation.') + values = [x.value for x in a] + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + else: + return jnp.asarray(a, dtype=dtype, order=order) + + +array = asarray + + +@set_module_as('brainunit.math') +def arange(*args, **kwargs): + ''' + Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity, optional + stop: number, Quantity, optional + step: number, optional + dtype: dtype, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + # arange has a bit of a complicated argument structure unfortunately + # we leave the actual checking of the number of arguments to numpy, though + + # default values + start = kwargs.pop("start", 0) + step = kwargs.pop("step", 1) + stop = kwargs.pop("stop", None) + if len(args) == 1: + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + stop = args[0] + elif len(args) == 2: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + start, stop = args + elif len(args) == 3: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + if step != 1: + raise TypeError("Duplicate definition of 'step'") + start, stop, step = args + elif len(args) > 3: + raise TypeError("Need between 1 and 3 non-keyword arguments") + + if stop is None: + raise TypeError("Missing stop argument.") + if stop is not None and not is_unitless(stop): + start = Quantity(start, dim=stop.dim) + + fail_for_dimension_mismatch( + start, + stop, + error_message=( + "Start value {start} and stop value {stop} have to have the same units." + ), + start=start, + stop=stop, + ) + fail_for_dimension_mismatch( + stop, + step, + error_message=( + "Stop value {stop} and step value {step} have to have the same units." + ), + stop=stop, + step=step, + ) + unit = getattr(stop, "dim", DIMENSIONLESS) + # start is a position-only argument in numpy 2.0 + # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only + # TODO: check whether this is still the case in the final release + if start == 0: + return Quantity( + jnp.arange( + start=start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + dim=unit, + ) + else: + return Quantity( + jnp.arange( + start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + dim=unit, + ) + + +@set_module_as('brainunit.math') +def linspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: int = 50, + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + retstep: bool, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "dim", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) + return Quantity(result, dim=unit) + + +@set_module_as('brainunit.math') +def logspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: Optional[int] = 50, + endpoint: Optional[bool] = True, + base: Optional[float] = 10.0, + dtype: Optional[bst.typing.DTypeLike] = None): + ''' + Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + base: float, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "dim", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) + return Quantity(result, dim=unit) + + +@set_module_as('brainunit.math') +def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], + val: Union[Quantity, bst.typing.ArrayLike], + wrap: Optional[bool] = False, + inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: + ''' + Fill the main diagonal of the given array of `a` with `val`. + + Args: + a: array_like, Quantity + val: scalar, Quantity + wrap: bool, optional + inplace: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + ''' + if isinstance(a, Quantity) and isinstance(val, Quantity): + fail_for_dimension_mismatch(a, val) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + elif is_unitless(a) or is_unitless(val): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + else: + raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') + + +@set_module_as('brainunit.math') +def array_split(ary: Union[Quantity, bst.typing.ArrayLike], + indices_or_sections: Union[int, bst.typing.ArrayLike], + axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: + ''' + Split an array into multiple sub-arrays. + + Args: + ary: array_like, Quantity + indices_or_sections: int, array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. + ''' + if isinstance(ary, Quantity): + return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + elif isinstance(ary, (jax.Array, np.ndarray)): + return jnp.array_split(ary, indices_or_sections, axis) + else: + raise ValueError(f'Unsupported type: {type(ary)} for array_split') + + +@set_module_as('brainunit.math') +def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy'): + ''' + Return coordinate matrices from coordinate vectors. + + Args: + xi: array_like, Quantity + copy: bool, optional + sparse: bool, optional + indexing: str, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. + ''' + from builtins import all as origin_all + if origin_all(isinstance(x, Quantity) for x in xi): + fail_for_dimension_mismatch(*xi) + return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) + elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): + return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) + else: + raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') + + +@set_module_as('brainunit.math') +def vander(x: Union[Quantity, bst.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: + ''' + Generate a Vandermonde matrix. + + Args: + x: array_like, Quantity + N: int, optional + increasing: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.vander(x, N=N, increasing=increasing) + else: + raise ValueError(f'Unsupported type: {type(x)} for vander') diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py new file mode 100644 index 0000000..c4a7c26 --- /dev/null +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -0,0 +1,821 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Union, Optional, Tuple, List) + +import jax +import jax.numpy as jnp +from jax import Array + +from ._utils import _compatible_with_quantity +from .._base import (Quantity, + ) + +__all__ = [ + # array manipulation + 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', + 'diagflat', 'diagonal', 'choose', 'ravel', +] + + +# array manipulation +# ------------------ + + +@_compatible_with_quantity() +def reshape(a: Union[Array, Quantity], shape: Union[int, Tuple[int, ...]], order: str = 'C') -> Union[Array, Quantity]: + return jnp.reshape(a, shape, order) + + +@_compatible_with_quantity() +def moveaxis(a: Union[Array, Quantity], source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.moveaxis(a, source, destination) + + +@_compatible_with_quantity() +def transpose(a: Union[Array, Quantity], axes: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.transpose(a, axes) + + +@_compatible_with_quantity() +def swapaxes(a: Union[Array, Quantity], axis1: int, axis2: int) -> Union[Array, Quantity]: + return jnp.swapaxes(a, axis1, axis2) + + +@_compatible_with_quantity() +def concatenate(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.concatenate(arrays, axis) + + +@_compatible_with_quantity() +def stack(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: int = 0) -> Union[Array, Quantity]: + return jnp.stack(arrays, axis) + + +@_compatible_with_quantity() +def vstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.vstack(arrays) + + +row_stack = vstack + + +@_compatible_with_quantity() +def hstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.hstack(arrays) + + +@_compatible_with_quantity() +def dstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.dstack(arrays) + + +@_compatible_with_quantity() +def column_stack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.column_stack(arrays) + + +@_compatible_with_quantity() +def split(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]], axis: int = 0) -> Union[ + List[Array], List[Quantity]]: + return jnp.split(a, indices_or_sections, axis) + + +@_compatible_with_quantity() +def dsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.dsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def hsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.hsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def vsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.vsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def tile(A: Union[Array, Quantity], reps: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.tile(A, reps) + + +@_compatible_with_quantity() +def repeat(a: Union[Array, Quantity], repeats: Union[int, Tuple[int, ...]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.repeat(a, repeats, axis) + + +@_compatible_with_quantity() +def unique(a: Union[Array, Quantity], return_index: bool = False, return_inverse: bool = False, + return_counts: bool = False, axis: Optional[int] = None) -> Union[Array, Quantity]: + return jnp.unique(a, return_index, return_inverse, return_counts, axis) + + +@_compatible_with_quantity() +def append(arr: Union[Array, Quantity], values: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.append(arr, values, axis) + + +@_compatible_with_quantity() +def flip(m: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.flip(m, axis) + + +@_compatible_with_quantity() +def fliplr(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.fliplr(m) + + +@_compatible_with_quantity() +def flipud(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.flipud(m) + + +@_compatible_with_quantity() +def roll(a: Union[Array, Quantity], shift: Union[int, Tuple[int, ...]], + axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.roll(a, shift, axis) + + +@_compatible_with_quantity() +def atleast_1d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_1d(*arys) + + +@_compatible_with_quantity() +def atleast_2d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_2d(*arys) + + +@_compatible_with_quantity() +def atleast_3d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_3d(*arys) + + +@_compatible_with_quantity() +def expand_dims(a: Union[Array, Quantity], axis: int) -> Union[Array, Quantity]: + return jnp.expand_dims(a, axis) + + +@_compatible_with_quantity() +def squeeze(a: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.squeeze(a, axis) + + +@_compatible_with_quantity() +def sort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Union[Array, Quantity]: + return jnp.sort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def argsort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Array: + return jnp.argsort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def max(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.max(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def min(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.min(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def choose(a: Union[Array, Quantity], choices: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.choose(a, choices) + + +@_compatible_with_quantity() +def block(arrays: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.block(arrays) + + +@_compatible_with_quantity() +def compress(condition: Union[Array, Quantity], a: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.compress(condition, a, axis) + + +@_compatible_with_quantity() +def diagflat(v: Union[Array, Quantity], k: int = 0) -> Union[Array, Quantity]: + return jnp.diagflat(v, k) + + +# return jax.numpy.Array, not Quantity + +@_compatible_with_quantity(return_quantity=False) +def argmax(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmax(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argmin(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmin(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argwhere(a: Union[Array, Quantity]) -> Array: + return jnp.argwhere(a) + + +@_compatible_with_quantity(return_quantity=False) +def nonzero(a: Union[Array, Quantity]) -> Tuple[Array, ...]: + return jnp.nonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def flatnonzero(a: Union[Array, Quantity]) -> Array: + return jnp.flatnonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def searchsorted(a: Union[Array, Quantity], v: Union[Array, Quantity], side: str = 'left', + sorter: Optional[Array] = None) -> Array: + return jnp.searchsorted(a, v, side, sorter) + + +@_compatible_with_quantity(return_quantity=False) +def extract(condition: Union[Array, Quantity], arr: Union[Array, Quantity]) -> Array: + return jnp.extract(condition, arr) + + +@_compatible_with_quantity(return_quantity=False) +def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Array: + return jnp.count_nonzero(a, axis) + + +amax = max +amin = min + +# docs for the functions above +reshape.__doc__ = ''' + Return a reshaped copy of an array or a Quantity. + + Args: + a: input array or Quantity to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + brainunit does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. +''' + +moveaxis.__doc__ = ''' + Moves axes of an array to new positions. Other axes remain in their original order. + + Args: + a: array_like, Quantity + source: int or sequence of ints + destination: int or sequence of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +transpose.__doc__ = ''' + Returns a view of the array with axes transposed. + + Args: + a: array_like, Quantity + axes: tuple or list of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +swapaxes.__doc__ = ''' + Interchanges two axes of an array. + + Args: + a: array_like, Quantity + axis1: int + axis2: int + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +concatenate.__doc__ = ''' + Join a sequence of arrays along an existing axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +stack.__doc__ = ''' + Join a sequence of arrays along a new axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +vstack.__doc__ = ''' + Stack arrays in sequence vertically (row wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +hstack.__doc__ = ''' + Stack arrays in sequence horizontally (column wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +dstack.__doc__ = ''' + Stack arrays in sequence depth wise (along third axis). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +column_stack.__doc__ = ''' + Stack 1-D arrays as columns into a 2-D array. + + Args: + arrays: sequence of 1-D or 2-D array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +split.__doc__ = ''' + Split an array into multiple sub-arrays. + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +dsplit.__doc__ = ''' + Split array along third axis (depth). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +hsplit.__doc__ = ''' + Split an array into multiple sub-arrays horizontally (column-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +vsplit.__doc__ = ''' + Split an array into multiple sub-arrays vertically (row-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +tile.__doc__ = ''' + Construct an array by repeating A the number of times given by reps. + + Args: + A: array_like, Quantity + reps: array_like + + Returns: + Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array +''' + +repeat.__doc__ = ''' + Repeat elements of an array. + + Args: + a: array_like, Quantity + repeats: array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +unique.__doc__ = ''' + Find the unique elements of an array. + + Args: + a: array_like, Quantity + return_index: bool, optional + return_inverse: bool, optional + return_counts: bool, optional + axis: int or None, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +append.__doc__ = ''' + Append values to the end of an array. + + Args: + arr: array_like, Quantity + values: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array +''' + +flip.__doc__ = ''' + Reverse the order of elements in an array along the given axis. + + Args: + m: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +fliplr.__doc__ = ''' + Flip array in the left/right direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +flipud.__doc__ = ''' + Flip array in the up/down direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +roll.__doc__ = ''' + Roll array elements along a given axis. + + Args: + a: array_like, Quantity + shift: int or tuple of ints + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +atleast_1d.__doc__ = ''' + View inputs as arrays with at least one dimension. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_2d.__doc__ = ''' + View inputs as arrays with at least two dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_3d.__doc__ = ''' + View inputs as arrays with at least three dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +expand_dims.__doc__ = ''' + Expand the shape of an array. + + Args: + a: array_like, Quantity + axis: int or tuple of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +squeeze.__doc__ = ''' + Remove single-dimensional entries from the shape of an array. + + Args: + a: array_like, Quantity + axis: None or int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +sort.__doc__ = ''' + Return a sorted copy of an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' +max.__doc__ = ''' + Return the maximum of an array or maximum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +min.__doc__ = ''' + Return the minimum of an array or minimum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +choose.__doc__ = ''' + Use an index array to construct a new array from a set of choices. + + Args: + a: array_like, Quantity + choices: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array +''' + +block.__doc__ = ''' + Assemble an nd-array from nested lists of blocks. + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +compress.__doc__ = ''' + Return selected slices of an array along given axis. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +diagflat.__doc__ = ''' + Create a two-dimensional array with the flattened input as a diagonal. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +argsort.__doc__ = ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort'}, optional + order: str or list of str, optional + + Returns: + jax.Array jax.numpy.Array (does not return a Quantity) +''' + +argmax.__doc__ = ''' + Returns indices of the max value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argmin.__doc__ = ''' + Returns indices of the min value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argwhere.__doc__ = ''' + Find indices of non-zero elements. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +nonzero.__doc__ = ''' + Return the indices of the elements that are non-zero. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +flatnonzero.__doc__ = ''' + Return indices that are non-zero in the flattened version of a. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +searchsorted.__doc__ = ''' + Find indices where elements should be inserted to maintain order. + + Args: + a: array_like, Quantity + v: array_like, Quantity + side: {'left', 'right'}, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +extract.__doc__ = ''' + Return the elements of an array that satisfy some condition. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +count_nonzero.__doc__ = ''' + Counts the number of non-zero values in the array a. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + + +def wrap_function_to_method(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_function_to_method +def diagonal(a: Union[jax.Array, Quantity], offset: int = 0, axis1: int = 0, axis2: int = 1) -> Union[ + jax.Array, Quantity]: + return jnp.diagonal(a, offset, axis1, axis2) + + +@wrap_function_to_method +def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Quantity]: + return jnp.ravel(a, order) + + +diagonal.__doc__ = ''' + Return specified diagonals. + + Args: + a: array_like, Quantity + offset: int, optional + axis1: int, optional + axis2: int, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +ravel.__doc__ = ''' + Return a contiguous flattened array. + + Args: + a: array_like, Quantity + order: {'C', 'F', 'A', 'K'}, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py new file mode 100644 index 0000000..c87890a --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -0,0 +1,588 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs only accept unitless (unary) + 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', + 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', + 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + + # math funcs only accept unitless (binary) + 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', +] + + +# math funcs only accept unitless (unary) +# --------------------------------------- + +def wrap_math_funcs_only_accept_unitless_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + return func(jnp.array(x.value), *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_unary +def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def expm1(x: Union[Array, Quantity]) -> Array: + return jnp.expm1(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log(x: Union[Array, Quantity]) -> Array: + return jnp.log(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log10(x: Union[Array, Quantity]) -> Array: + return jnp.log10(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log1p(x: Union[Array, Quantity]) -> Array: + return jnp.log1p(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log2(x: Union[Array, Quantity]) -> Array: + return jnp.log2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccos(x: Union[Array, Quantity]) -> Array: + return jnp.arccos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccosh(x: Union[Array, Quantity]) -> Array: + return jnp.arccosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsin(x: Union[Array, Quantity]) -> Array: + return jnp.arcsin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsinh(x: Union[Array, Quantity]) -> Array: + return jnp.arcsinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctan(x: Union[Array, Quantity]) -> Array: + return jnp.arctan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctanh(x: Union[Array, Quantity]) -> Array: + return jnp.arctanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cos(x: Union[Array, Quantity]) -> Array: + return jnp.cos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cosh(x: Union[Array, Quantity]) -> Array: + return jnp.cosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sin(x: Union[Array, Quantity]) -> Array: + return jnp.sin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinc(x: Union[Array, Quantity]) -> Array: + return jnp.sinc(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinh(x: Union[Array, Quantity]) -> Array: + return jnp.sinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tan(x: Union[Array, Quantity]) -> Array: + return jnp.tan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tanh(x: Union[Array, Quantity]) -> Array: + return jnp.tanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def deg2rad(x: Union[Array, Quantity]) -> Array: + return jnp.deg2rad(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def rad2deg(x: Union[Array, Quantity]) -> Array: + return jnp.rad2deg(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def degrees(x: Union[Array, Quantity]) -> Array: + return jnp.degrees(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def radians(x: Union[Array, Quantity]) -> Array: + return jnp.radians(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def angle(x: Union[Array, Quantity]) -> Array: + return jnp.angle(x) + + +# docs for the functions above +exp.__doc__ = ''' + Calculate the exponential of all elements in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +exp2.__doc__ = ''' + Calculate 2 raised to the power of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +expm1.__doc__ = ''' + Calculate the exponential of the input elements minus 1. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log.__doc__ = ''' + Natural logarithm, element-wise. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log10.__doc__ = ''' + Base-10 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log1p.__doc__ = ''' + Natural logarithm of 1 + the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log2.__doc__ = ''' + Base-2 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccos.__doc__ = ''' + Compute the arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccosh.__doc__ = ''' + Compute the hyperbolic arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsin.__doc__ = ''' + Compute the arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsinh.__doc__ = ''' + Compute the hyperbolic arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan.__doc__ = ''' + Compute the arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctanh.__doc__ = ''' + Compute the hyperbolic arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cos.__doc__ = ''' + Compute the cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cosh.__doc__ = ''' + Compute the hyperbolic cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sin.__doc__ = ''' + Compute the sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinc.__doc__ = ''' + Compute the sinc function of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinh.__doc__ = ''' + Compute the hyperbolic sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tan.__doc__ = ''' + Compute the tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tanh.__doc__ = ''' + Compute the hyperbolic tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +deg2rad.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +rad2deg.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +degrees.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +radians.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +angle.__doc__ = ''' + Return the angle of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs only accept unitless (binary) +# ---------------------------------------- + +def wrap_math_funcs_only_accept_unitless_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + fail_for_dimension_mismatch( + y, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=y, + ) + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_binary +def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.hypot(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.arctan2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.percentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanpercentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.quantile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanquantile(a, q, *args, **kwargs) + + +# docs for the functions above +hypot.__doc__ = ''' + Given the “legs” of a right triangle, return its hypotenuse. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan2.__doc__ = ''' + Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp2.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs in base-2. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +percentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanpercentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +quantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanquantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py new file mode 100644 index 0000000..1325539 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -0,0 +1,183 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array +from numpy import number + +from .._base import (Quantity, + ) + +__all__ = [ + + # Elementwise bit operations (unary) + 'bitwise_not', 'invert', + + # Elementwise bit operations (binary) + 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', +] + + +# Elementwise bit operations (unary) +# ---------------------------------- + +def wrap_elementwise_bit_operation_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected integers, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_unary +def bitwise_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_not(x) + + +@wrap_elementwise_bit_operation_unary +def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.invert(x) + + +# docs for functions above +bitwise_not.__doc__ = ''' + Compute the bit-wise NOT of an array, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + +invert.__doc__ = ''' + Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + + +# Elementwise bit operations (binary) +# ----------------------------------- + +def wrap_elementwise_bit_operation_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) or isinstance(y, Quantity): + raise ValueError(f'Expected integers, got {x} and {y}') + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_binary +def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_and(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_or(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_xor(x, y) + + +@wrap_elementwise_bit_operation_binary +def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.left_shift(x, y) + + +@wrap_elementwise_bit_operation_binary +def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.right_shift(x, y) + + +# docs for functions above +bitwise_and.__doc__ = ''' + Compute the bit-wise AND of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_or.__doc__ = ''' + Compute the bit-wise OR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_xor.__doc__ = ''' + Compute the bit-wise XOR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +left_shift.__doc__ = ''' + Shift the bits of an integer to the left. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +right_shift.__doc__ = ''' + Shift the bits of an integer to the right. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py new file mode 100644 index 0000000..227234c --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -0,0 +1,527 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from ._compat_numpy_get_attribute import isscalar +from .._base import (DIMENSIONLESS, + Quantity, + ) +from .._base import _return_check_unitless + +__all__ = [ + + # math funcs change unit (unary) + 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', + 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', + + # math funcs change unit (binary) + 'multiply', 'divide', 'power', 'cross', 'ldexp', + 'true_divide', 'floor_divide', 'float_power', + 'divmod', 'remainder', 'convolve', +] + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(change_unit_func: Callable) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.dim))) + elif isinstance(x, (jnp.ndarray, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** -1) +def reciprocal(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.reciprocal(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def var(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.var(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def nanvar(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.nanvar(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x * 2 ** -1) +def frexp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.frexp(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 0.5) +def sqrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sqrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** (1 / 3)) +def cbrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cbrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.square(x) + + +# docs for the functions above + +reciprocal.__doc__ = ''' + Return the reciprocal of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +var.__doc__ = ''' + Compute the variance along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +nanvar.__doc__ = ''' + Compute the variance along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +frexp.__doc__ = ''' + Decompose a floating-point number into its mantissa and exponent. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. +''' + +sqrt.__doc__ = ''' + Compute the square root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. +''' + +cbrt.__doc__ = ''' + Compute the cube root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. +''' + +square.__doc__ = ''' + Compute the square of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + + +@set_module_as('brainunit.math') +def prod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: Optional[bool] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None, + promote_integers: bool = True) -> Union[Quantity, jax.Array]: + ''' + Return the product of array elements over a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + promote_integers: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + else: + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + + +@set_module_as('brainunit.math') +def nanprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: bool = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None): + ''' + Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + + +product = prod + + +@set_module_as('brainunit.math') +def cumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('brainunit.math') +def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(change_unit_func): + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.dim, y.dim)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.dim, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.dim))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + return decorator + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def multiply(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.multiply(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def cross(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.cross(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * 2 ** y) +def ldexp(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.ldexp(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def true_divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.true_divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divmod(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divmod(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.convolve(x, y) + + +# docs for the functions above +multiply.__doc__ = ''' + Multiply arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +divide.__doc__ = ''' + Divide arguments element-wise. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +cross.__doc__ = ''' + Return the cross product of two (arrays of) vectors. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +ldexp.__doc__ = ''' + Return x1 * 2**x2, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. +''' + +true_divide.__doc__ = ''' + Returns a true division of the inputs, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +divmod.__doc__ = ''' + Return element-wise quotient and remainder simultaneously. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +convolve.__doc__ = ''' + Returns the discrete, linear convolution of two one-dimensional sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + + +@set_module_as('brainunit.math') +def power(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.power(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y), dim=x.dim ** y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x, y.value), dim=x ** y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + + +@set_module_as('brainunit.math') +def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return the largest integer smaller or equal to the division of the inputs. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.floor_divide(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), dim=x.dim / y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), dim=x / y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('brainunit.math') +def float_power(x: Union[Quantity, bst.typing.ArrayLike], + y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(y, Quantity): + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), dim=x.dim ** y)) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.float_power(x, y) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + + +@set_module_as('brainunit.math') +def remainder(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return element-wise remainder of division. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.remainder(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), dim=x.dim % y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), dim=x % y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py new file mode 100644 index 0000000..7f8d8fc --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -0,0 +1,166 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + + # indexing funcs + 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', +] + + +# indexing funcs +# -------------- +@set_module_as('brainunit.math') +def where(condition: Union[bool, bst.typing.ArrayLike], + *args: Union[Quantity, bst.typing.ArrayLike], + **kwds) -> Union[Quantity, jax.Array]: + condition = jnp.asarray(condition) + if len(args) == 0: + # nothing to do + return jnp.where(condition, *args, **kwds) + elif len(args) == 2: + # check that x and y have the same dimensions + fail_for_dimension_mismatch( + args[0], args[1], "x and y need to have the same dimensions" + ) + new_args = [] + for arg in args: + if isinstance(arg, Quantity): + new_args.append(arg.value) + if is_unitless(args[0]): + if len(new_args) == 2: + return jnp.where(condition, *new_args, **kwds) + else: + return jnp.where(condition, *args, **kwds) + else: + # as both arguments have the same unit, just use the first one's + dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] + return Quantity.with_units( + jnp.where(condition, *dimensionless_args), args[0].dim + ) + else: + # illegal number of arguments + if len(args) == 1: + raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") + elif len(args) > 2: + raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) + + +tril_indices = jnp.tril_indices +tril_indices.__doc__ = ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.tril_indices_from(arr.value, k=k) + else: + return jnp.tril_indices_from(arr, k=k) + + +triu_indices = jnp.triu_indices +triu_indices.__doc__ = ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.triu_indices_from(arr.value, k=k) + else: + return jnp.triu_indices_from(arr, k=k) + + +@set_module_as('brainunit.math') +def take(a: Union[Quantity, bst.typing.ArrayLike], + indices: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None) -> Union[Quantity, jax.Array]: + if isinstance(a, Quantity): + return a.take(indices, axis=axis, mode=mode) + else: + return jnp.take(a, indices, axis=axis, mode=mode) + + +@set_module_as('brainunit.math') +def select(condlist: list[Union[bst.typing.ArrayLike]], + choicelist: Union[Quantity, bst.typing.ArrayLike], + default: int = 0) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(choice, Quantity) for choice in choicelist): + if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): + raise ValueError("All choices must have the same unit") + else: + return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), + dim=choicelist[0].dim) + elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): + return jnp.select(condlist, choicelist, default=default) + else: + raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py new file mode 100644 index 0000000..4a6616e --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -0,0 +1,832 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # math funcs keep unit (unary) + 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', + 'abs', 'round', 'around', 'round_', 'rint', + 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', + 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', + 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', + 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', + + # math funcs keep unit (binary) + 'fmod', 'mod', 'copysign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', + + # math funcs keep unit (n-ary) + 'interp', 'clip', +] + + +# math funcs keep unit (unary) +# ---------------------------- + +def wrap_math_funcs_keep_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_unary +def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.real(x) + + +@wrap_math_funcs_keep_unit_unary +def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.imag(x) + + +@wrap_math_funcs_keep_unit_unary +def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conj(x) + + +@wrap_math_funcs_keep_unit_unary +def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conjugate(x) + + +@wrap_math_funcs_keep_unit_unary +def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.negative(x) + + +@wrap_math_funcs_keep_unit_unary +def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.positive(x) + + +@wrap_math_funcs_keep_unit_unary +def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.abs(x) + + +@wrap_math_funcs_keep_unit_unary +def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.around(x) + + +@wrap_math_funcs_keep_unit_unary +def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.rint(x) + + +@wrap_math_funcs_keep_unit_unary +def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.floor(x) + + +@wrap_math_funcs_keep_unit_unary +def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ceil(x) + + +@wrap_math_funcs_keep_unit_unary +def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.trunc(x) + + +@wrap_math_funcs_keep_unit_unary +def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fix(x) + + +@wrap_math_funcs_keep_unit_unary +def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sum(x) + + +@wrap_math_funcs_keep_unit_unary +def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nancumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nansum(x) + + +@wrap_math_funcs_keep_unit_unary +def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ediff1d(x) + + +@wrap_math_funcs_keep_unit_unary +def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.absolute(x) + + +@wrap_math_funcs_keep_unit_unary +def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fabs(x) + + +@wrap_math_funcs_keep_unit_unary +def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.median(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmin(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmax(x) + + +@wrap_math_funcs_keep_unit_unary +def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ptp(x) + + +@wrap_math_funcs_keep_unit_unary +def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.average(x) + + +@wrap_math_funcs_keep_unit_unary +def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.mean(x) + + +@wrap_math_funcs_keep_unit_unary +def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.std(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmedian(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmean(x) + + +@wrap_math_funcs_keep_unit_unary +def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanstd(x) + + +@wrap_math_funcs_keep_unit_unary +def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.diff(x) + + +@wrap_math_funcs_keep_unit_unary +def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.modf(x) + + +# docs for the functions above +real.__doc__ = ''' + Return the real part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +imag.__doc__ = ''' + Return the imaginary part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conj.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conjugate.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +negative.__doc__ = ''' + Return the negative of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +positive.__doc__ = ''' + Return the positive of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +abs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round_.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +around.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +rint.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +floor.__doc__ = ''' + Return the floor of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ceil.__doc__ = ''' + Return the ceiling of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +trunc.__doc__ = ''' + Return the truncated value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fix.__doc__ = ''' + Return the nearest integer towards zero. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +sum.__doc__ = ''' + Return the sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nancumsum.__doc__ = ''' + Return the cumulative sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nansum.__doc__ = ''' + Return the sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +cumsum.__doc__ = ''' + Return the cumulative sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ediff1d.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +absolute.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fabs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +median.__doc__ = ''' + Return the median of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmin.__doc__ = ''' + Return the minimum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmax.__doc__ = ''' + Return the maximum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ptp.__doc__ = ''' + Return the range of the array elements (maximum - minimum). + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +average.__doc__ = ''' + Return the weighted average of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +mean.__doc__ = ''' + Return the mean of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +std.__doc__ = ''' + Return the standard deviation of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmedian.__doc__ = ''' + Return the median of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmean.__doc__ = ''' + Return the mean of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanstd.__doc__ = ''' + Return the standard deviation of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +diff.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +modf.__doc__ = ''' + Return the fractional and integer parts of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. +''' + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + @wraps(func) + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_binary +def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.mod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.copysign(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.heaviside(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.maximum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.minimum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmax(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmin(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.lcm(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.gcd(x1, x2) + + +# docs for the functions above +fmod.__doc__ = ''' + Return the element-wise remainder of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +mod.__doc__ = ''' + Return the element-wise modulus of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +copysign.__doc__ = ''' + Return a copy of the first array elements with the sign of the second array. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +heaviside.__doc__ = ''' + Compute the Heaviside step function. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +maximum.__doc__ = ''' + Element-wise maximum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +minimum.__doc__ = ''' + Element-wise minimum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmax.__doc__ = ''' + Element-wise maximum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmin.__doc__ = ''' + Element-wise minimum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +lcm.__doc__ = ''' + Return the least common multiple of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +gcd.__doc__ = ''' + Return the greatest common divisor of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('brainunit.math') +def interp(x: Union[Quantity, bst.typing.ArrayLike], + xp: Union[Quantity, bst.typing.ArrayLike], + fp: Union[Quantity, bst.typing.ArrayLike], + left: Union[Quantity, bst.typing.ArrayLike] = None, + right: Union[Quantity, bst.typing.ArrayLike] = None, + period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: + ''' + One-dimensional linear interpolation. + + Args: + x: array_like, Quantity + xp: array_like, Quantity + fp: array_like, Quantity + left: array_like, Quantity, optional + right: array_like, Quantity, optional + period: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, dim=unit) + else: + return result + + +@set_module_as('brainunit.math') +def clip(a: Union[Quantity, bst.typing.ArrayLike], + a_min: Union[Quantity, bst.typing.ArrayLike], + a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Clip (limit) the values in an array. + + Args: + a: array_like, Quantity + a_min: array_like, Quantity + a_max: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, dim=unit) + else: + return result diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py new file mode 100644 index 0000000..e7d69e7 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -0,0 +1,343 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # logic funcs (unary) + 'all', 'any', 'logical_not', + + # logic funcs (binary) + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', + 'logical_or', 'logical_xor', "alltrue", 'sometrue', +] + + +# logic funcs (unary) +# ------------------- + +def wrap_logic_func_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected booleans, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_unary +def all(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.all(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def any(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.any(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.logical_not(x) + + +alltrue = all +sometrue = any + +# docs for functions above +all.__doc__ = ''' + Test whether all array elements along a given axis evaluate to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +any.__doc__ = ''' + Test whether any array element along a given axis evaluates to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_not.__doc__ = ''' + Compute the truth value of NOT x element-wise. + + Args: + x: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + + +# logic funcs (binary) +# -------------------- + +def wrap_logic_func_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_binary +def equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.equal(x, y) + + +@wrap_logic_func_binary +def not_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.not_equal(x, y) + + +@wrap_logic_func_binary +def greater(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater(x, y) + + +@wrap_logic_func_binary +def greater_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater_equal(x, y) + + +@wrap_logic_func_binary +def less(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less(x, y) + + +@wrap_logic_func_binary +def less_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less_equal(x, y) + + +@wrap_logic_func_binary +def array_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.array_equal(x, y) + + +@wrap_logic_func_binary +def isclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def allclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def logical_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_and(x, y) + + +@wrap_logic_func_binary +def logical_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_or(x, y) + + +@wrap_logic_func_binary +def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_xor(x, y) + + +# docs for functions above +equal.__doc__ = ''' + Return (x == y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +not_equal.__doc__ = ''' + Return (x != y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater.__doc__ = ''' + Return (x > y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater_equal.__doc__ = ''' + Return (x >= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less.__doc__ = ''' + Return (x < y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less_equal.__doc__ = ''' + Return (x <= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +array_equal.__doc__ = ''' + Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +isclose.__doc__ = ''' + Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +allclose.__doc__ = ''' + Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + bool: boolean result +''' + +logical_and.__doc__ = ''' + Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_or.__doc__ = ''' + Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_xor.__doc__ = ''' + Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py new file mode 100644 index 0000000..d9926ad --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -0,0 +1,108 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs match unit (binary) + 'add', 'subtract', 'nextafter', +] + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_match_unit_binary +def add(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.add(x, y) + + +@wrap_math_funcs_match_unit_binary +def subtract(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.subtract(x, y) + + +@wrap_math_funcs_match_unit_binary +def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.nextafter(x, y) + + +# docs for the functions above +add.__doc__ = ''' + Add arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +subtract.__doc__ = ''' + Subtract arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +nextafter.__doc__ = ''' + Return the next floating-point value after `x1` towards `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' diff --git a/brainunit/math/_compat_numpy_funcs_remove_unit.py b/brainunit/math/_compat_numpy_funcs_remove_unit.py new file mode 100644 index 0000000..afea533 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_remove_unit.py @@ -0,0 +1,191 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + ) + +__all__ = [ + + # math funcs remove unit (unary) + 'signbit', 'sign', 'histogram', 'bincount', + + # math funcs remove unit (binary) + 'corrcoef', 'correlate', 'cov', 'digitize', +] + + +# math funcs remove unit (unary) +# ------------------------------ +def wrap_math_funcs_remove_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return func(x.value, *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_unary +def signbit(x: Union[Array, Quantity]) -> Array: + return jnp.signbit(x) + + +@wrap_math_funcs_remove_unit_unary +def sign(x: Union[Array, Quantity]) -> Array: + return jnp.sign(x) + + +@wrap_math_funcs_remove_unit_unary +def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: + return jnp.histogram(x) + + +@wrap_math_funcs_remove_unit_unary +def bincount(x: Union[Array, Quantity]) -> Array: + return jnp.bincount(x) + + +# docs for the functions above +signbit.__doc__ = ''' + Returns element-wise True where signbit is set (less than zero). + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sign.__doc__ = ''' + Returns the sign of each element in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +histogram.__doc__ = ''' + Compute the histogram of a set of data. + + Args: + x: array_like, Quantity + + Returns: + tuple[jax.Array]: Tuple of arrays (hist, bin_edges) +''' + +bincount.__doc__ = ''' + Count number of occurrences of each value in array of non-negative integers. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs remove unit (binary) +# ------------------------------- +def wrap_math_funcs_remove_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_binary +def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.corrcoef(x, y) + + +@wrap_math_funcs_remove_unit_binary +def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.correlate(x, y) + + +@wrap_math_funcs_remove_unit_binary +def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: + return jnp.cov(x, y) + + +@wrap_math_funcs_remove_unit_binary +def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: + return jnp.digitize(x, bins) + + +# docs for the functions above +corrcoef.__doc__ = ''' + Return Pearson product-moment correlation coefficients. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +correlate.__doc__ = ''' + Cross-correlation of two sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cov.__doc__ = ''' + Covariance matrix. + + Args: + x: array_like, Quantity + y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) + + Returns: + jax.Array: an array +''' + +digitize.__doc__ = ''' + Return the indices of the bins to which each value in input array belongs. + + Args: + x: array_like, Quantity + bins: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py new file mode 100644 index 0000000..776450f --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -0,0 +1,69 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps + +import jax.numpy as jnp +from jax import Array + +__all__ = [ + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', +] + + +# window funcs +# ------------ + +def wrap_window_funcs(func): + @wraps(func) + def f(*args, **kwargs): + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_window_funcs +def bartlett(M: int) -> Array: + return jnp.bartlett(M) + + +@wrap_window_funcs +def blackman(M: int) -> Array: + return jnp.blackman(M) + + +@wrap_window_funcs +def hamming(M: int) -> Array: + return jnp.hamming(M) + + +@wrap_window_funcs +def hanning(M: int) -> Array: + return jnp.hanning(M) + + +@wrap_window_funcs +def kaiser(M: int, beta: float) -> Array: + return jnp.kaiser(M, beta) + + +# docs for functions above +bartlett.__doc__ = jnp.bartlett.__doc__ +blackman.__doc__ = jnp.blackman.__doc__ +hamming.__doc__ = jnp.hamming.__doc__ +hanning.__doc__ = jnp.hanning.__doc__ +kaiser.__doc__ = jnp.kaiser.__doc__ diff --git a/brainunit/math/_compat_numpy_get_attribute.py b/brainunit/math/_compat_numpy_get_attribute.py new file mode 100644 index 0000000..03bec0d --- /dev/null +++ b/brainunit/math/_compat_numpy_get_attribute.py @@ -0,0 +1,215 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # getting attribute funcs + 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', + 'isnan', 'shape', 'size', +] + + +@set_module_as('brainunit.math') +def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: + ''' + Return the number of dimensions of an array. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: int + ''' + if isinstance(a, Quantity): + return a.ndim + else: + return jnp.ndim(a) + + +@set_module_as('brainunit.math') +def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return True if the input array is real. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isreal + else: + return jnp.isreal(a) + + +@set_module_as('brainunit.math') +def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: + ''' + Return True if the input is a scalar. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isscalar + else: + return jnp.isscalar(a) + + +@set_module_as('brainunit.math') +def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is finite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isfinite + else: + return jnp.isfinite(a) + + +@set_module_as('brainunit.math') +def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is infinite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isinf + else: + return jnp.isinf(a) + + +@set_module_as('brainunit.math') +def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is NaN or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isnan + else: + return jnp.isnan(a) + + +@set_module_as('brainunit.math') +def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: + """ + Return the shape of an array. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also + -------- + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples + -------- + >>> brainunit.math.shape(brainunit.math.eye(3)) + (3, 3) + >>> brainunit.math.shape([[1, 3]]) + (1, 2) + >>> brainunit.math.shape([0]) + (1,) + >>> brainunit.math.shape(0) + () + + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) + + +@set_module_as('brainunit.math') +def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: + """ + Return the number of elements along a given axis. + + Parameters + ---------- + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns + ------- + element_count : int + Number of elements along the specified axis. + + See Also + -------- + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples + -------- + >>> a = Quantity([[1,2,3], [4,5,6]]) + >>> brainunit.math.size(a) + 6 + >>> brainunit.math.size(a, 1) + 3 + >>> brainunit.math.size(a, 0) + 2 + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] + else: + return np.size(a, axis=axis) diff --git a/brainunit/math/_compat_numpy_linear_algebra.py b/brainunit/math/_compat_numpy_linear_algebra.py new file mode 100644 index 0000000..88f27e9 --- /dev/null +++ b/brainunit/math/_compat_numpy_linear_algebra.py @@ -0,0 +1,149 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import jax.numpy as jnp +from jax import Array + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from .._base import (Quantity, + ) + +__all__ = [ + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + +] + + + + +# linear algebra +# -------------- + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def dot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.dot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.vdot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.inner(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.outer(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.kron(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.matmul(a, b) + + +@wrap_math_funcs_keep_unit_unary +def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.trace(a) + + +# docs for functions above +dot.__doc__ = ''' + Dot product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +vdot.__doc__ = ''' + Return the dot product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +inner.__doc__ = ''' + Inner product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +outer.__doc__ = ''' + Compute the outer product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +kron.__doc__ = ''' + Compute the Kronecker product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +matmul.__doc__ = ''' + Matrix product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +trace.__doc__ = ''' + Return the sum of the diagonal elements of a matrix or quantity. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. +''' diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py new file mode 100644 index 0000000..0deb591 --- /dev/null +++ b/brainunit/math/_compat_numpy_misc.py @@ -0,0 +1,354 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from typing import (Callable, Union, Tuple) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +import opt_einsum +from brainstate._utils import set_module_as +from jax import Array +from jax._src.numpy.lax_numpy import _einsum + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from ._utils import _compatible_with_quantity +from .._base import (DIMENSIONLESS, + Quantity, + fail_for_dimension_mismatch, + is_unitless, + get_unit, ) + +__all__ = [ + + # constants + 'e', 'pi', 'inf', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'broadcast_arrays', 'broadcast_shapes', + 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', + 'rot90', 'tensordot', +] + +# constants +# --------- +e = jnp.e +pi = jnp.pi +inf = jnp.inf + +# data types +# ---------- +dtype = jnp.dtype + + +@set_module_as('brainunit.math') +def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: + if isinstance(a, Quantity): + return jnp.finfo(a.value) + else: + return jnp.finfo(a) + + +@set_module_as('brainunit.math') +def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: + if isinstance(a, Quantity): + return jnp.iinfo(a.value) + else: + return jnp.iinfo(a) + + +# more +# ---- +@set_module_as('brainunit.math') +def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, list[Array]]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(arg, Quantity) for arg in args): + if origin_any(arg.dim != args[0].dim for arg in args): + raise ValueError("All arguments must have the same unit") + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) + elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): + return jnp.broadcast_arrays(*args) + else: + raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + + +broadcast_shapes = jnp.broadcast_shapes + + +@set_module_as('brainunit.math') +def einsum( + subscripts: str, + /, + *operands: Union[Quantity, jax.Array], + out: None = None, + optimize: Union[str, bool] = "optimal", + precision: jax.lax.PrecisionLike = None, + preferred_element_type: Union[jax.typing.DTypeLike, None] = None, + _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, +) -> Union[jax.Array, Quantity]: + ''' + Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays or quantities corresponding to the subscripts. + optimize: determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + ''' + operands = (subscripts, *operands) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") + spec = operands[0] if isinstance(operands[0], str) else None + optimize = 'optimal' if optimize is True else optimize + + # Allow handling of shape polymorphism + non_constant_dim_types = { + type(d) for op in operands if not isinstance(op, str) + for d in np.shape(op) if not jax.core.is_constant_dim(d) + } + if not non_constant_dim_types: + contract_path = opt_einsum.contract_path + else: + from jax._src.numpy.lax_numpy import _default_poly_einsum_handler + contract_path = _default_poly_einsum_handler + + operands, contractions = contract_path( + *operands, einsum_call=True, use_blas=True, optimize=optimize) + + unit = None + for i in range(len(contractions) - 1): + if contractions[i][4] == 'False': + + fail_for_dimension_mismatch( + Quantity([], dim=unit), operands[i + 1], 'einsum' + ) + elif contractions[i][4] == 'DOT' or \ + contractions[i][4] == 'TDOT' or \ + contractions[i][4] == 'GEMM' or \ + contractions[i][4] == 'OUTER/EINSUM': + if i == 0: + if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): + unit = operands[i].dim * operands[i + 1].dim + elif isinstance(operands[i], Quantity): + unit = operands[i].dim + elif isinstance(operands[i + 1], Quantity): + unit = operands[i + 1].dim + else: + if isinstance(operands[i + 1], Quantity): + unit = unit * operands[i + 1].dim + + contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) + + einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + if spec is not None: + einsum = jax.named_call(einsum, name=spec) + operands = [op.value if isinstance(op, Quantity) else op for op in operands] + r = einsum(operands, contractions, precision, # type: ignore[operator] + preferred_element_type, _dot_general) + if unit is not None: + return Quantity(r, dim=unit) + else: + return r + + +@set_module_as('brainunit.math') +def gradient( + f: Union[bst.typing.ArrayLike, Quantity], + *varargs: Union[bst.typing.ArrayLike, Quantity], + axis: Union[int, Sequence[int], None] = None, + edge_order: Union[int, None] = None, +) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + ''' + Computes the gradient of a scalar field. + + Args: + f: input array. + *varargs: list of scalar fields to compute the gradient. + axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. + edge_order: order of the edge used for the finite difference computation. The default is 1. + + Returns: + array containing the gradient of the scalar field. + ''' + if edge_order is not None: + raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") + + if len(varargs) == 0: + if isinstance(f, Quantity) and not is_unitless(f): + return Quantity(jnp.gradient(f.value, axis=axis), dim=f.dim) + else: + return jnp.gradient(f) + elif len(varargs) == 1: + unit = get_unit(f) / get_unit(varargs[0]) + if unit is None or unit == DIMENSIONLESS: + return jnp.gradient(f, varargs[0], axis=axis) + else: + return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + else: + unit_list = [get_unit(f) / get_unit(v) for v in varargs] + f = f.value if isinstance(f, Quantity) else f + varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] + result_list = jnp.gradient(f, *varargs, axis=axis) + return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + + +@set_module_as('brainunit.math') +def intersect1d( + ar1: Union[bst.typing.ArrayLike], + ar2: Union[bst.typing.ArrayLike], + assume_unique: bool = False, + return_indices: bool = False +) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: + ''' + Find the intersection of two arrays. + + Args: + ar1: input array. + ar2: input array. + assume_unique: if True, the input arrays are both assumed to be unique. + return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. + + Returns: + array containing the intersection of the two arrays. + ''' + fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') + unit = None + if isinstance(ar1, Quantity): + unit = ar1.dim + ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 + ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 + result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + if return_indices: + if unit is not None: + return (Quantity(result[0], dim=unit), result[1], result[2]) + else: + return result + else: + if unit is not None: + return Quantity(result, dim=unit) + else: + return result + + +@wrap_math_funcs_keep_unit_unary +def nan_to_num(x: Union[bst.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, + neginf: float = -jnp.inf) -> Union[jax.Array, Quantity]: + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +@wrap_math_funcs_keep_unit_unary +def rot90(m: Union[bst.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ + jax.Array, Quantity]: + return jnp.rot90(m, k=k, axes=axes) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def tensordot(a: Union[bst.typing.ArrayLike, Quantity], b: Union[bst.typing.ArrayLike, Quantity], + axes: Union[int, Tuple[int, int]] = 2) -> Union[jax.Array, Quantity]: + return jnp.tensordot(a, b, axes=axes) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmax(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmax(a, axis=axis) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmin(a, axis=axis) + + +# docs for functions above +nan_to_num.__doc__ = ''' + Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. + + Args: + x: input array. + nan: value to replace NaNs with. + posinf: value to replace positive infinity with. + neginf: value to replace negative infinity with. + + Returns: + array with NaNs replaced by zero and infinities replaced by large finite numbers. +''' + +nanargmax.__doc__ = ''' + Return the index of the maximum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the maximum value in the array. +''' + +nanargmin.__doc__ = ''' + Return the index of the minimum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the minimum value in the array. +''' + +rot90.__doc__ = ''' + Rotate an array by 90 degrees in the plane specified by axes. + + Args: + m: array like, Quantity. + k: number of times the array is rotated by 90 degrees. + axes: plane of rotation. Default is the last two axes. + + Returns: + rotated array. +''' + +tensordot.__doc__ = ''' + Compute tensor dot product along specified axes for arrays. + + Args: + a: array like, Quantity. + b: array like, Quantity. + axes: axes along which to compute the tensor dot product. + + Returns: + tensor dot product of the two arrays. +''' diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 258cd85..8e39796 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -24,6 +24,7 @@ from brainunit import DimensionMismatchError from brainunit._base import Quantity from brainunit._unit_shortcuts import ms, mV +from brainunit._unit_common import second bst.environ.set(precision=64) @@ -31,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.unit == unit.dim, f"Unit mismatch: {q.unit} != {unit}" + assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" @@ -44,6 +45,10 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) + q = bu.math.full(3, 4, unit=second) + self.assertEqual(q.shape, (3,)) + assert_quantity(q, result, second) + def test_eye(self): result = bu.math.eye(3) self.assertEqual(result.shape, (3, 3)) @@ -87,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4 * bu.second) + result_q = bu.math.full_like(q, 4, unit=bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): @@ -97,7 +102,7 @@ def test_diag(self): self.assertTrue(jnp.all(result == jnp.diag(array))) q = [1, 2, 3] * bu.second - result_q = bu.math.diag(q) + result_q = bu.math.diag(q, unit=bu.second) assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), bu.second) def test_tril(self): @@ -107,7 +112,7 @@ def test_tril(self): self.assertTrue(jnp.all(result == jnp.tril(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.tril(q) + result_q = bu.math.tril(q, unit=bu.second) assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), bu.second) def test_triu(self): @@ -117,7 +122,7 @@ def test_triu(self): self.assertTrue(jnp.all(result == jnp.triu(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.triu(q) + result_q = bu.math.triu(q, unit=bu.second) assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), bu.second) def test_empty_like(self): @@ -1706,7 +1711,7 @@ def test_argsort(self): q = [2, 3, 1] * bu.second result_q = bu.math.argsort(q) expected_q = jnp.argsort(jnp.array([2, 3, 1])) - assert jnp.all(result_q == expected_q) + assert_quantity(result_q, expected_q, bu.second) def test_argmax(self): array = jnp.array([2, 3, 1]) @@ -1810,22 +1815,6 @@ def test_invert(self): q = [0b1100] * bu.second result_q = bu.math.invert(q) - def test_left_shift(self): - result = bu.math.left_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.left_shift(q, 2) - - def test_right_shift(self): - result = bu.math.right_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.right_shift(q, 2) - class TestElementwiseBitOperationsBinary(unittest.TestCase): @@ -1856,6 +1845,22 @@ def test_bitwise_xor(self): q2 = [0b1010] * bu.second result_q = bu.math.bitwise_xor(q1, q2) + def test_left_shift(self): + result = bu.math.left_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.left_shift(q, 2) + + def test_right_shift(self): + result = bu.math.right_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.right_shift(q, 2) + class TestLogicFuncsUnary(unittest.TestCase): def test_all(self): diff --git a/brainunit/math/_others.py b/brainunit/math/_others.py index 720edba..d316eb4 100644 --- a/brainunit/math/_others.py +++ b/brainunit/math/_others.py @@ -16,7 +16,7 @@ import brainstate as bst -from ._compat_numpy import wrap_math_funcs_only_accept_unitless_unary +from ._compat_numpy_funcs_accept_unitless import wrap_math_funcs_only_accept_unitless_unary __all__ = [ 'exprel', diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index b934c1f..61242e0 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -15,8 +15,9 @@ import functools -from typing import Callable +from typing import Callable, Union +import jax from jax.tree_util import tree_map from .._base import Quantity @@ -31,74 +32,64 @@ def _is_leaf(a): def _compatible_with_quantity( - fun: Callable, return_quantity: bool = True, - module: str = '' ): - func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun - - @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs): - unit = None - if isinstance(args[0], Quantity): - unit = args[0].dim - elif isinstance(args[0], tuple): - if len(args[0]) == 1: - unit = args[0][0].dim if isinstance(args[0][0], Quantity) else None - elif len(args[0]) == 2: - # check all args[0] have the same unit - if all(isinstance(a, Quantity) for a in args[0]): - if all(a.dim == args[0][0].dim for a in args[0]): - unit = args[0][0].dim + def decorator(fun: Callable) -> Callable: + @functools.wraps(fun) + def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: + unit = None + if isinstance(args[0], Quantity): + unit = args[0].dim + elif isinstance(args[0], tuple): + if len(args[0]) == 1: + unit = args[0][0].dim if isinstance(args[0][0], Quantity) else None + elif len(args[0]) == 2: + # check all args[0] have the same unit + if all(isinstance(a, Quantity) for a in args[0]): + if all(a.dim == args[0][0].dim for a in args[0]): + unit = args[0][0].dim + else: + raise ValueError(f'Units do not match for {fun.__name__} operation.') + elif all(not isinstance(a, Quantity) for a in args[0]): + unit = None else: raise ValueError(f'Units do not match for {fun.__name__} operation.') - elif all(not isinstance(a, Quantity) for a in args[0]): - unit = None - else: - raise ValueError(f'Units do not match for {fun.__name__} operation.') - args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) - out = None - if len(kwargs): - # compatible with PyTorch syntax - if 'dim' in kwargs: - kwargs['axis'] = kwargs.pop('dim') - if 'keepdim' in kwargs: - kwargs['keepdims'] = kwargs.pop('keepdim') - # compatible with TensorFlow syntax - if 'keep_dims' in kwargs: - kwargs['keepdims'] = kwargs.pop('keep_dims') - # compatible with NumPy/PyTorch syntax - if 'out' in kwargs: - out = kwargs.pop('out') - if not isinstance(out, Quantity): - raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') - # format - kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) + out = None + if len(kwargs): + # compatible with PyTorch syntax + if 'dim' in kwargs: + kwargs['axis'] = kwargs.pop('dim') + if 'keepdim' in kwargs: + kwargs['keepdims'] = kwargs.pop('keepdim') + # compatible with TensorFlow syntax + if 'keep_dims' in kwargs: + kwargs['keepdims'] = kwargs.pop('keep_dims') + # compatible with NumPy/PyTorch syntax + if 'out' in kwargs: + out = kwargs.pop('out') + if not isinstance(out, Quantity): + raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') + # format + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) - if not return_quantity: - unit = None + if not return_quantity: + unit = None - r = fun(*args, **kwargs) - if unit is not None: - if isinstance(r, (list, tuple)): - return [Quantity(rr, dim=unit) for rr in r] - else: - if out is None: - return Quantity(r, dim=unit) + r = fun(*args, **kwargs) + if unit is not None: + if isinstance(r, (list, tuple)): + return [Quantity(rr, dim=unit) for rr in r] else: - out.value = r - if out is None: - return r - else: - out.value = r + if out is None: + return Quantity(r, dim=unit) + else: + out.value = r + if out is None: + return r + else: + out.value = r - new_fun.__doc__ = ( - f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' - f'while it is compatible with brainpy Array/Variable. \n\n' - f'Note that this function is also compatible with:\n\n' - f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' - f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' - f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' - ) + return new_fun - return new_fun + return decorator diff --git a/docs/apis/brainunit.math.rst b/docs/apis/brainunit.math.rst index a6ab19c..7d3601d 100644 --- a/docs/apis/brainunit.math.rst +++ b/docs/apis/brainunit.math.rst @@ -1,9 +1,398 @@ ``brainunit.math`` module -========================== +========================= -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math +.. currentmodule:: brainunit.math +.. automodule:: brainunit.math + +Array Creation +-------------- .. autosummary:: :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + full + full_like + eye + identity + diag + tri + tril + triu + empty + empty_like + ones + ones_like + zeros + zeros_like + array + asarray + arange + linspace + logspace + fill_diagonal + array_split + meshgrid + vander + + +Array Manipulation +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reshape + moveaxis + transpose + swapaxes + row_stack + concatenate + stack + vstack + hstack + dstack + column_stack + split + dsplit + hsplit + vsplit + tile + repeat + unique + append + flip + fliplr + flipud + roll + atleast_1d + atleast_2d + atleast_3d + expand_dims + squeeze + sort + argsort + argmax + argmin + argwhere + nonzero + flatnonzero + searchsorted + extract + count_nonzero + max + min + amax + amin + block + compress + diagflat + diagonal + choose + ravel + + +Functions Accepting Unitless +---------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + exp + exp2 + expm1 + log + log10 + log1p + log2 + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + cos + cosh + sin + sinc + sinh + tan + tanh + deg2rad + rad2deg + degrees + radians + angle + percentile + nanpercentile + quantile + nanquantile + hypot + arctan2 + logaddexp + logaddexp2 + + +Functions with Bitwise Operations +--------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bitwise_not + invert + bitwise_and + bitwise_or + bitwise_xor + left_shift + right_shift + + +Functions Changing Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reciprocal + prod + product + nancumprod + nanprod + cumprod + cumproduct + var + nanvar + cbrt + square + frexp + sqrt + multiply + divide + power + cross + ldexp + true_divide + floor_divide + float_power + divmod + remainder + convolve + + +Indexing Functions +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + where + tril_indices + tril_indices_from + triu_indices + triu_indices_from + take + select + + +Functions Keeping Unit +---------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + real + imag + conj + conjugate + negative + positive + abs + round + around + round_ + rint + floor + ceil + trunc + fix + sum + nancumsum + nansum + cumsum + ediff1d + absolute + fabs + median + nanmin + nanmax + ptp + average + mean + std + nanmedian + nanmean + nanstd + diff + modf + fmod + mod + copysign + heaviside + maximum + minimum + fmax + fmin + lcm + gcd + interp + clip + + +Logical Functions +----------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + all + any + logical_not + equal + not_equal + greater + greater_equal + less + less_equal + array_equal + isclose + allclose + logical_and + logical_or + logical_xor + alltrue + sometrue + + +Functions Matching Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + add + subtract + nextafter + + +Functions Removing Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + signbit + sign + histogram + bincount + corrcoef + correlate + cov + digitize + + +Window Functions +---------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bartlett + blackman + hamming + hanning + kaiser + + +Get Attribute Functions +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ndim + isreal + isscalar + isfinite + isinf + isnan + shape + size + + +Linear Algebra Functions +------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + dot + vdot + inner + outer + kron + matmul + trace + + +More Functions +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + finfo + iinfo + broadcast_arrays + broadcast_shapes + einsum + gradient + intersect1d + nan_to_num + nanargmax + nanargmin + rot90 + tensordot + dtype + e + pi + inf + diff --git a/docs/auto_generater.py b/docs/auto_generater.py index b192b88..7b76528 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -226,7 +226,6 @@ def _write_subsections_v4(module_path, fout.write(f'.. currentmodule:: {out_path} \n') fout.write(f'.. automodule:: {out_path} \n\n') - fout.write('.. autosummary::\n') fout.write(' :toctree: generated/\n') fout.write(' :nosignatures:\n') @@ -319,14 +318,29 @@ def _section(header, numpy_mod, brainpy_mod, jax_mod, klass=None, is_jax=False): def main(): os.makedirs('apis/auto/', exist_ok=True) - _write_module(module_name='brainunit', - filename='apis/brainunit.math.rst', - header='``brainunit.init`` module') - - -if __name__ == '__main__': - main() - + module_and_name = [ + ('_compat_numpy_array_creation', 'Array Creation'), + ('_compat_numpy_array_manipulation', 'Array Manipulation'), + ('_compat_numpy_funcs_accept_unitless', 'Functions Accepting Unitless'), + ('_compat_numpy_funcs_bit_operation', 'Functions with Bitwise Operations'), + ('_compat_numpy_funcs_change_unit', 'Functions Changing Unit'), + ('_compat_numpy_funcs_indexing', 'Indexing Functions'), + ('_compat_numpy_funcs_keep_unit', 'Functions Keeping Unit'), + ('_compat_numpy_funcs_logic', 'Logical Functions'), + ('_compat_numpy_funcs_match_unit', 'Functions Matching Unit'), + ('_compat_numpy_funcs_remove_unit', 'Functions Removing Unit'), + ('_compat_numpy_funcs_window', 'Window Functions'), + ('_compat_numpy_get_attribute', 'Get Attribute Functions'), + ('_compat_numpy_linear_algebra', 'Linear Algebra Functions'), + ('_compat_numpy_misc', 'More Functions'), + ] + _write_submodules(module_name='brainunit.math', + filename='apis/brainunit.math.rst', + header='``brainunit.math`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) +if __name__ == '__main__': + main()