From 7aa810bf0416a73a4b444dac8dcbde5fb5cc24af Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Jan 2023 21:36:52 +0800 Subject: [PATCH 1/3] upgrade array op structure: 1. compatibility for NumPy 2. compatibility for PyTorch 3. compatibility for TensorFlow --- brainpy/_src/math/__init__.py | 4 +- brainpy/_src/math/arraycreation.py | 101 --------- brainpy/_src/math/arrayoperation.py | 108 --------- .../{arraycompatible.py => compat_numpy.py} | 100 ++++++++- brainpy/_src/math/compat_pytorch.py | 58 +++++ brainpy/_src/math/compat_tensorflow.py | 211 ++++++++++++++++++ brainpy/_src/math/delayvars.py | 2 +- brainpy/_src/math/index_tricks.py | 2 +- .../_src/math/object_transform/controls.py | 3 - brainpy/_src/math/others.py | 32 +++ brainpy/_src/math/surrogate/compt.py | 2 +- ...rayoperation.py => test_compat_pytorch.py} | 8 +- brainpy/math/__init__.py | 6 +- brainpy/math/arrayinterporate.py | 10 + brainpy/math/arrayoperation.py | 29 --- .../{arraycompatible.py => compat_numpy.py} | 15 +- brainpy/math/compat_pytorch.py | 4 + brainpy/math/compat_tensorflow.py | 22 ++ brainpy/math/others.py | 5 +- 19 files changed, 463 insertions(+), 259 deletions(-) delete mode 100644 brainpy/_src/math/arraycreation.py delete mode 100644 brainpy/_src/math/arrayoperation.py rename brainpy/_src/math/{arraycompatible.py => compat_numpy.py} (89%) create mode 100644 brainpy/_src/math/compat_pytorch.py create mode 100644 brainpy/_src/math/compat_tensorflow.py rename brainpy/_src/math/tests/{test_arrayoperation.py => test_compat_pytorch.py} (74%) create mode 100644 brainpy/math/arrayinterporate.py delete mode 100644 brainpy/math/arrayoperation.py rename brainpy/math/{arraycompatible.py => compat_numpy.py} (96%) create mode 100644 brainpy/math/compat_pytorch.py create mode 100644 brainpy/math/compat_tensorflow.py diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index dbe499714..0ad51ab88 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -39,9 +39,9 @@ from . import activations # high-level numpy operations -from .arraycreation import * from .arrayinterporate import * -from .arraycompatible import * +from .compat_numpy import * +from .compat_tensorflow import * from .others import * from . 
import random, linalg, fft diff --git a/brainpy/_src/math/arraycreation.py b/brainpy/_src/math/arraycreation.py deleted file mode 100644 index d111e9c53..000000000 --- a/brainpy/_src/math/arraycreation.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -from jax.tree_util import tree_flatten, tree_unflatten - -from ._utils import wraps -from .ndarray import Array - -__all__ = [ - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', -] - - -def _as_jax_array_(obj): - return obj.value if isinstance(obj, Array) else obj - - -@wraps(jnp.zeros) -def zeros(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) - - -@wraps(jnp.ones) -def ones(shape, dtype=None): - return Array(jnp.ones(shape, dtype=dtype)) - - -@wraps(jnp.empty) -def empty(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) - - -@wraps(jnp.zeros_like) -def zeros_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) - - -@wraps(jnp.ones_like) -def ones_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) - - -@wraps(jnp.empty_like) -def empty_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) - - -@wraps(jnp.array) -def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: - a = _as_jax_array_(a) - try: - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - a = tree_unflatten(tree, leaves) - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return Array(res) - - -@wraps(jnp.asarray) -def asarray(a, dtype=None, order=None): - a = _as_jax_array_(a) - try: - res = jnp.asarray(a=a, dtype=dtype, order=order) - except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - arrays = tree_unflatten(tree, leaves) - res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return Array(res) - - -@wraps(jnp.arange) -def arange(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.arange(*args, **kwargs)) - - -@wraps(jnp.linspace) -def linspace(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - res = jnp.linspace(*args, **kwargs) - if isinstance(res, tuple): - return Array(res[0]), res[1] - else: - return Array(res) - - -@wraps(jnp.logspace) -def logspace(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.logspace(*args, **kwargs)) - diff --git a/brainpy/_src/math/arrayoperation.py b/brainpy/_src/math/arrayoperation.py deleted file mode 100644 index 12f1b5d8f..000000000 --- a/brainpy/_src/math/arrayoperation.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Optional - -import jax -import jax.numpy as jnp -from jax.tree_util import tree_map -import numpy as np - -from .arrayinterporate import as_jax -from .ndarray import Array - -__all__ = [ - 'flatten', - 'fill_diagonal', - 'remove_diag', - 'clip_by_norm', -] - - -def flatten(input: Union[jax.Array, Array], - start_dim: Optional[int] 
= None, - end_dim: Optional[int] = None) -> jax.Array: - """Flattens input by reshaping it into a one-dimensional tensor. - If ``start_dim`` or ``end_dim`` are passed, only dimensions starting - with ``start_dim`` and ending with ``end_dim`` are flattened. - The order of elements in input is unchanged. - - .. note:: - Flattening a zero-dimensional tensor will return a one-dimensional view. - - Parameters - ---------- - input: Array - The input array. - start_dim: int - the first dim to flatten - end_dim: int - the last dim to flatten - - Returns - ------- - out: Array - """ - input = as_jax(input) - shape = input.shape - ndim = input.ndim - if ndim == 0: - ndim = 1 - if start_dim is None: - start_dim = 0 - elif start_dim < 0: - start_dim = ndim + start_dim - if end_dim is None: - end_dim = ndim - 1 - elif end_dim < 0: - end_dim = ndim + end_dim - end_dim += 1 - if start_dim < 0 or start_dim > ndim: - raise ValueError(f'start_dim {start_dim} is out of size.') - if end_dim < 0 or end_dim > ndim: - raise ValueError(f'end_dim {end_dim} is out of size.') - new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:] - return jnp.reshape(input, new_shape) - - -def fill_diagonal(a, val, inplace=True): - if a.ndim < 2: - raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') - if not isinstance(a, Array) and inplace: - raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' - 'it requires a brainpy Array. If you want to disable ' - 'inplace updating, use ``fill_diagonal(inplace=False)``.') - val = val.value if isinstance(val, Array) else val - i, j = jnp.diag_indices(min(a.shape[-2:])) - r = as_jax(a).at[..., i, j].set(val) - if inplace: - a.value = r - else: - return r - - -def remove_diag(arr): - """Remove the diagonal of the matrix. - - Parameters - ---------- - arr: ArrayType - The matrix with the shape of `(M, N)`. - - Returns - ------- - arr: Array - The matrix without diagonal which has the shape of `(M, N-1)`. 
- """ - if arr.ndim != 2: - raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = Array(jnp.ones(arr.shape, dtype=bool)) - fill_diagonal(eyes, False) - return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) - - -def clip_by_norm(t, clip_norm, axis=None): - def f(l): - return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm) - - return tree_map(f, t) diff --git a/brainpy/_src/math/arraycompatible.py b/brainpy/_src/math/compat_numpy.py similarity index 89% rename from brainpy/_src/math/arraycompatible.py rename to brainpy/_src/math/compat_numpy.py index 598112f06..6e4ca7fe9 100644 --- a/brainpy/_src/math/arraycompatible.py +++ b/brainpy/_src/math/compat_numpy.py @@ -3,14 +3,16 @@ import jax.numpy as jnp import numpy as np from jax.tree_util import tree_map +from jax.tree_util import tree_flatten, tree_unflatten -from ._utils import _compatible_with_brainpy_array -from .arraycreation import * +from ._utils import _compatible_with_brainpy_array, _as_jax_array_ from .arrayinterporate import * from .ndarray import Array __all__ = [ 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', # math funcs 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', @@ -97,10 +99,98 @@ ] + _min = min _max = max +def fill_diagonal(a, val, inplace=True): + if a.ndim < 2: + raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') + if not isinstance(a, Array) and inplace: + raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' + 'it requires a brainpy Array. If you want to disable ' + 'inplace updating, use ``fill_diagonal(inplace=False)``.') + val = val.value if isinstance(val, Array) else val + i, j = jnp.diag_indices(min(a.shape[-2:])) + r = as_jax(a).at[..., i, j].set(val) + if inplace: + a.value = r + else: + return r + +def zeros(shape, dtype=None): + return Array(jnp.zeros(shape, dtype=dtype)) + + +def ones(shape, dtype=None): + return Array(jnp.ones(shape, dtype=dtype)) + + +def empty(shape, dtype=None): + return Array(jnp.zeros(shape, dtype=dtype)) + + +def zeros_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + + +def ones_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) + + +def empty_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + + +def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: + a = _as_jax_array_(a) + try: + res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) + except TypeError: + leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) + leaves = [_as_jax_array_(l) for l in leaves] + a = tree_unflatten(tree, leaves) + res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) + return Array(res) + + +def asarray(a, dtype=None, order=None): + a = _as_jax_array_(a) + try: + res = jnp.asarray(a=a, dtype=dtype, order=order) + except TypeError: + leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) + leaves = [_as_jax_array_(l) for l in leaves] + arrays = tree_unflatten(tree, leaves) + res = jnp.asarray(a=arrays, dtype=dtype, order=order) + return Array(res) + + +def arange(*args, **kwargs): + args = [_as_jax_array_(a) 
for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + return Array(jnp.arange(*args, **kwargs)) + + +def linspace(*args, **kwargs): + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + res = jnp.linspace(*args, **kwargs) + if isinstance(res, tuple): + return Array(res[0]), res[1] + else: + return Array(res) + + +def logspace(*args, **kwargs): + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + return Array(jnp.logspace(*args, **kwargs)) + def asanyarray(a, dtype=None, order=None): return asarray(a, dtype=dtype, order=order) @@ -249,7 +339,9 @@ def asfarray(a, dtype=np.float_): trunc = _compatible_with_brainpy_array(jnp.trunc) fix = _compatible_with_brainpy_array(jnp.fix) prod = _compatible_with_brainpy_array(jnp.prod) + sum = _compatible_with_brainpy_array(jnp.sum) + diff = _compatible_with_brainpy_array(jnp.diff) median = _compatible_with_brainpy_array(jnp.median) nancumprod = _compatible_with_brainpy_array(jnp.nancumprod) @@ -305,7 +397,9 @@ def asfarray(a, dtype=np.float_): logical_or = _compatible_with_brainpy_array(jnp.logical_or) logical_xor = _compatible_with_brainpy_array(jnp.logical_xor) all = _compatible_with_brainpy_array(jnp.all) + any = _compatible_with_brainpy_array(jnp.any) + alltrue = all sometrue = any @@ -356,7 +450,9 @@ def asfarray(a, dtype=np.float_): extract = _compatible_with_brainpy_array(jnp.extract) count_nonzero = _compatible_with_brainpy_array(jnp.count_nonzero) max = _compatible_with_brainpy_array(jnp.max) + min = _compatible_with_brainpy_array(jnp.min) + amax = max amin = min apply_along_axis = _compatible_with_brainpy_array(jnp.apply_along_axis) diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py new file mode 100644 index 000000000..8fc5f1125 --- /dev/null +++ b/brainpy/_src/math/compat_pytorch.py @@ -0,0 +1,58 @@ +from typing import Union, Optional + +import jax +import jax.numpy as jnp +import numpy as np + +from .ndarray import Array, _as_jax_array_ + +__all__ = [ + 'flatten', +] + + +def flatten(input: Union[jax.Array, Array], + start_dim: Optional[int] = None, + end_dim: Optional[int] = None) -> jax.Array: + """Flattens input by reshaping it into a one-dimensional tensor. + If ``start_dim`` or ``end_dim`` are passed, only dimensions starting + with ``start_dim`` and ending with ``end_dim`` are flattened. + The order of elements in input is unchanged. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Parameters + ---------- + input: Array + The input array. 
+  start_dim: int
+    the first dim to flatten
+  end_dim: int
+    the last dim to flatten
+
+  Returns
+  -------
+  out: Array
+  """
+  input = _as_jax_array_(input)
+  shape = input.shape
+  ndim = input.ndim
+  if ndim == 0:
+    ndim = 1
+  if start_dim is None:
+    start_dim = 0
+  elif start_dim < 0:
+    start_dim = ndim + start_dim
+  if end_dim is None:
+    end_dim = ndim - 1
+  elif end_dim < 0:
+    end_dim = ndim + end_dim
+  end_dim += 1
+  if start_dim < 0 or start_dim > ndim:
+    raise ValueError(f'start_dim {start_dim} is out of size.')
+  if end_dim < 0 or end_dim > ndim:
+    raise ValueError(f'end_dim {end_dim} is out of size.')
+  new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:]
+  return jnp.reshape(input, new_shape)
+
diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py
new file mode 100644
index 000000000..c46fabe21
--- /dev/null
+++ b/brainpy/_src/math/compat_tensorflow.py
@@ -0,0 +1,211 @@
+import jax.numpy as jnp
+import jax.ops
+
+from .ndarray import _return, _as_jax_array_
+from .compat_numpy import prod, min, sum, all, any, mean, std, var
+
+__all__ = [
+  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all',
+  'reduce_any', 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance',
+  'reduce_euclidean_norm',
+  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum',
+  'unsorted_segment_prod', 'unsorted_segment_max', 'unsorted_segment_min',
+  'unsorted_segment_mean',
+]
+
+
+def reduce_logsumexp(input_tensor, axis=None, keep_dims=False):
+  """Computes log(sum(exp(elements across dimensions of a tensor))).
+
+  Reduces `input_tensor` along the dimensions given in `axis`.
+
+  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+  of the entries in `axis`, which must be unique. If `keepdims` is true, the
+  reduced dimensions are retained with length 1.
+
+  If `axis` has no entries, all dimensions are reduced, and a
+  tensor with a single element is returned.
+
+  This function is more numerically stable than log(sum(exp(input))). It avoids
+  overflows caused by taking the exp of large inputs and underflows caused by
+  taking the log of small inputs.
+
+  Args:
+    input_tensor: The tensor to reduce. Should have numeric type.
+    axis: The dimensions to reduce. If `None` (the default), reduces all
+      dimensions. Must be in the range `[-rank(input_tensor),
+      rank(input_tensor))`.
+    keep_dims: If true, retains reduced dimensions with length 1.
+
+  Returns:
+    The reduced tensor.
+  """
+  x = _as_jax_array_(input_tensor)
+  # Subtract the max before exponentiating so that exp() cannot overflow;
+  # this shift is what makes the result more stable than a naive
+  # log(sum(exp(x))).
+  m = jnp.max(x, axis=axis, keepdims=True)
+  m = jnp.where(jnp.isfinite(m), m, 0.)
+  r = jnp.log(jnp.sum(jnp.exp(x - m), axis=axis, keepdims=keep_dims))
+  r += m if keep_dims else jnp.squeeze(m, axis=axis)
+  return _return(r)
+
+
+def reduce_euclidean_norm(input_tensor, axis=None, keep_dims=False):
+  """Computes the Euclidean norm of elements across dimensions of a tensor.
+
+  Reduces `input_tensor` along the dimensions given in `axis`.
+
+  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+  of the entries in `axis`, which must be unique. If `keepdims` is true, the
+  reduced dimensions are retained with length 1.
+
+  If `axis` is None, all dimensions are reduced, and a
+  tensor with a single element is returned.
+
+  Args:
+    input_tensor: The tensor to reduce. Should have numeric type.
+    axis: The dimensions to reduce. If `None` (the default), reduces all
+      dimensions. Must be in the range `[-rank(input_tensor),
+      rank(input_tensor))`.
+    keep_dims: If true, retains reduced dimensions with length 1.
+
+  Returns:
+    The reduced tensor, of the same dtype as the input_tensor.
+ """ + r = jnp.linalg.norm(_as_jax_array_(input_tensor), axis=axis, keep_dims=keep_dims) + return _return(r) + + +def reduce_max(input_tensor, axis=None, keep_dims=False): + """Computes `maximum` of elements across dimensions of a tensor. + + This is the reduction operation for the elementwise `maximum` op. + Reduces `input_tensor` along the dimensions given in `axis`. + + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + Args: + input_tensor: The tensor to reduce. Should have real numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keep_dims: If true, retains reduced dimensions with length 1. + + Returns: + The reduced tensor. + """ + return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keep_dims=keep_dims)) + + +reduce_prod = prod +reduce_sum = sum +reduce_all = all +reduce_any = any +reduce_min = min +reduce_mean = mean +reduce_std = std +reduce_variance = var + + + +def segment_mean(data, segment_ids): + """Computes the average along segments of a tensor. + + See https://tensorflow.google.cn/api_docs/python/tf/math/segment_mean + + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + indices_are_sorted=True) + d = jax.ops.segment_sum(jnp.ones_like(data), + _as_jax_array_(segment_ids), + indices_are_sorted=True) + return _return(jnp.nan_to_num(r / d)) + + +def unsorted_segment_sum(data, segment_ids, num_segments): + """Computes the sum along segments of a tensor. + + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sum + + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=True) + return _return(r) + + +def unsorted_segment_prod(data, segment_ids, num_segments): + """Computes the product along segments of a tensor. + + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_prod + + """ + r = jax.ops.segment_prod(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=True) + return _return(r) + + +def unsorted_segment_max(data, segment_ids, num_segments): + """Computes the maximum along segments of a tensor. + + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_max + + """ + r = jax.ops.segment_max(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=True) + return _return(r) + + + +def unsorted_segment_min(data, segment_ids, num_segments): + """Computes the minimum along segments of a tensor. + + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_min + + """ + r = jax.ops.segment_min(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=True) + return _return(r) + + + +def unsorted_segment_sqrt_n(data, segment_ids, num_segments): + """Computes the sum along segments of a tensor divided by the sqrt(N). 
+
+  See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sqrt_n
+
+  """
+  r = jax.ops.segment_sum(_as_jax_array_(data),
+                          _as_jax_array_(segment_ids),
+                          num_segments=num_segments,
+                          indices_are_sorted=False)
+  d = jax.ops.segment_sum(jnp.ones_like(_as_jax_array_(data)),
+                          _as_jax_array_(segment_ids),
+                          num_segments=num_segments,
+                          indices_are_sorted=False)
+  return _return(jnp.nan_to_num(r / jnp.sqrt(d)))
+
+
+def unsorted_segment_mean(data, segment_ids, num_segments):
+  """Computes the average along segments of a tensor.
+
+  See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_mean
+
+  """
+  r = jax.ops.segment_sum(_as_jax_array_(data),
+                          _as_jax_array_(segment_ids),
+                          num_segments=num_segments,
+                          indices_are_sorted=False)
+  d = jax.ops.segment_sum(jnp.ones_like(_as_jax_array_(data)),
+                          _as_jax_array_(segment_ids),
+                          num_segments=num_segments,
+                          indices_are_sorted=False)
+  return _return(jnp.nan_to_num(r / d))
+
+
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index bc028396b..889e43828 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -13,7 +13,7 @@
 from .environment import get_dt, get_float
 from .ndarray import ndarray, Variable, Array
 from .arrayinterporate import as_jax
-from . import arraycompatible as bm
+from . import compat_numpy as bm
 
 __all__ = [
   'AbstractDelay',
diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py
index 8709d4073..d10b0d0e5 100644
--- a/brainpy/_src/math/index_tricks.py
+++ b/brainpy/_src/math/index_tricks.py
@@ -1,7 +1,7 @@
 import abc
 from jax import core
 
-from .arraycompatible import arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
+from .compat_numpy import arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
 import numpy as np
 
 __all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"]
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index c2e6df2b0..e55008fe8 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -783,7 +783,6 @@ def fun2scan(carry, x):
   name = get_unique_name('_brainpy_object_oriented_for_loop_')
 
   # functions
-  # init_vals = [v.value for v in dyn_vars]
   try:
     add_context(name)
     dyn_vals, out_vals = lax.scan(f=fun2scan,
@@ -794,11 +793,9 @@
     del_context(name)
   except UnexpectedTracerError as e:
     del_context(name)
-    # for v, d in zip(dyn_vars, init_vals): v._value = d
     raise errors.JaxTracerError() from e
   except Exception as e:
     del_context(name)
-    # for v, d in zip(dyn_vars, init_vals): v._value = d
     raise e
   else:
     for v, d in zip(dyn_vars, dyn_vals): v._value = d
diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py
index 8406ec332..31e97df88 100644
--- a/brainpy/_src/math/others.py
+++ b/brainpy/_src/math/others.py
@@ -4,12 +4,17 @@
 from typing import Optional
 
 import jax.numpy as jnp
+from jax.tree_util import tree_map
 
 from brainpy import check, tools
 from .environment import get_dt, get_int
+from .ndarray import Array
+from .compat_numpy import fill_diagonal
 
 __all__ = [
   'shared_args_over_time',
+  'remove_diag',
+  'clip_by_norm',
]
 
 
@@ -50,3 +55,30 @@
   if include_dt:
     r['dt'] = jnp.ones_like(r['t']) * dt
   return r
+
+
+def remove_diag(arr):
+  """Remove the diagonal of the matrix.
+
+  Parameters
+  ----------
+  arr: ArrayType
+    The matrix with the shape of `(M, N)`.
+
+  Returns
+  -------
+  arr: Array
+    The matrix without its diagonal, which has the shape of `(M, N-1)`.
+  """
+  if arr.ndim != 2:
+    raise ValueError(f'Only support 2D matrices, but got a {arr.ndim}D array.')
+  eyes = Array(jnp.ones(arr.shape, dtype=bool))
+  fill_diagonal(eyes, False)
+  return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1))
+
+
+def clip_by_norm(t, clip_norm, axis=None):
+  def f(l):
+    return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm)
+
+  return tree_map(f, t)
diff --git a/brainpy/_src/math/surrogate/compt.py b/brainpy/_src/math/surrogate/compt.py
index b089f0b0d..0cfa108d1 100644
--- a/brainpy/_src/math/surrogate/compt.py
+++ b/brainpy/_src/math/surrogate/compt.py
@@ -4,7 +4,7 @@
 
 from jax import custom_gradient, numpy as jnp
 
-from brainpy._src.math.arraycreation import asarray
+from brainpy._src.math.compat_numpy import asarray
 from brainpy._src.math.arrayinterporate import as_jax
 from brainpy._src.math.environment import get_float
 from brainpy._src.math.ndarray import Array
diff --git a/brainpy/_src/math/tests/test_arrayoperation.py b/brainpy/_src/math/tests/test_compat_pytorch.py
similarity index 74%
rename from brainpy/_src/math/tests/test_arrayoperation.py
rename to brainpy/_src/math/tests/test_compat_pytorch.py
index 5f86ebabb..83e8823e9 100644
--- a/brainpy/_src/math/tests/test_arrayoperation.py
+++ b/brainpy/_src/math/tests/test_compat_pytorch.py
@@ -6,7 +6,7 @@
 import jax.numpy as jnp
 import unittest
 import brainpy.math as bm
-from brainpy._src.math import arrayoperation
+from brainpy._src.math import compat_pytorch
 from absl.testing import parameterized
 
 
@@ -15,15 +15,15 @@ class TestFlatten(unittest.TestCase):
   def test1(self):
     rng = bm.random.default_rng(113)
     arr = rng.rand(3, 4, 5)
-    a2 = arrayoperation.flatten(arr, 1, 2)
+    a2 = compat_pytorch.flatten(arr, 1, 2)
     self.assertTrue(a2.shape == (3, 20))
-    a2 = arrayoperation.flatten(arr, 0, 1)
+    a2 = compat_pytorch.flatten(arr, 0, 1)
     self.assertTrue(a2.shape == (12, 5))
 
   def test2(self):
     rng = bm.random.default_rng(234)
     arr = rng.rand()
     self.assertTrue(arr.ndim == 0)
-    arr = arrayoperation.flatten(arr)
+    arr = compat_pytorch.flatten(arr)
     self.assertTrue(arr.ndim == 1)
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index 83ccdcb56..29e2e31dd 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -4,8 +4,10 @@
 # data structure
 from .ndarray import *
 from .delayvars import *
-from .arrayoperation import *
-from .arraycompatible import *
+from .arrayinterporate import *
+from .compat_numpy import *
+from .compat_tensorflow import *
+from .compat_pytorch import *
 
 # functions
 from .activations import *
diff --git a/brainpy/math/arrayinterporate.py b/brainpy/math/arrayinterporate.py
new file mode 100644
index 000000000..69f7d221d
--- /dev/null
+++ b/brainpy/math/arrayinterporate.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.math.arrayinterporate import (
+  as_device_array as as_device_array,
+  as_jax as as_jax,
+  as_ndarray as as_ndarray,
+  as_numpy as as_numpy,
+  as_variable as as_variable,
+)
+
diff --git a/brainpy/math/arrayoperation.py b/brainpy/math/arrayoperation.py
deleted file mode 100644
index 3709ee102..000000000
--- a/brainpy/math/arrayoperation.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.math.arrayoperation import (
-  flatten as flatten,
-  fill_diagonal as fill_diagonal,
-  remove_diag as remove_diag,
-  clip_by_norm as clip_by_norm,
-)
-from
brainpy._src.math.arraycreation import ( - empty as empty, - empty_like as empty_like, - ones as ones, - ones_like as ones_like, - zeros as zeros, - zeros_like as zeros_like, - array as array, - asarray as asarray, - arange as arange, - linspace as linspace, - logspace as logspace, -) -from brainpy._src.math.arrayinterporate import ( - as_device_array as as_device_array, - as_jax as as_jax, - as_ndarray as as_ndarray, - as_numpy as as_numpy, - as_variable as as_variable, -) - diff --git a/brainpy/math/arraycompatible.py b/brainpy/math/compat_numpy.py similarity index 96% rename from brainpy/math/arraycompatible.py rename to brainpy/math/compat_numpy.py index 4326672a8..6ae3003ee 100644 --- a/brainpy/math/arraycompatible.py +++ b/brainpy/math/compat_numpy.py @@ -1,7 +1,18 @@ # -*- coding: utf-8 -*- - -from brainpy._src.math.arraycompatible import ( +from brainpy._src.math.compat_numpy import ( + fill_diagonal as fill_diagonal, + empty as empty, + empty_like as empty_like, + ones as ones, + ones_like as ones_like, + zeros as zeros, + zeros_like as zeros_like, + array as array, + asarray as asarray, + arange as arange, + linspace as linspace, + logspace as logspace, full as full, full_like as full_like, eye as eye, diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py new file mode 100644 index 000000000..157a9f967 --- /dev/null +++ b/brainpy/math/compat_pytorch.py @@ -0,0 +1,4 @@ + +from brainpy._src.math.compat_pytorch import ( + flatten as flatten, +) diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py new file mode 100644 index 000000000..b027ca40d --- /dev/null +++ b/brainpy/math/compat_tensorflow.py @@ -0,0 +1,22 @@ + +from brainpy._src.math.compat_tensorflow import ( + reduce_sum as reduce_sum, + reduce_max as reduce_max, + reduce_min as reduce_min, + reduce_mean as reduce_mean, + reduce_all as reduce_all, + reduce_any as reduce_any, + reduce_logsumexp as reduce_logsumexp, + reduce_prod as reduce_prod, + reduce_std as reduce_std, + reduce_variance as reduce_variance, + reduce_euclidean_norm as reduce_euclidean_norm, + unsorted_segment_sqrt_n as unsorted_segment_sqrt_n, + segment_mean as segment_mean, + unsorted_segment_sum as unsorted_segment_sum, + unsorted_segment_prod as unsorted_segment_prod, + unsorted_segment_max as unsorted_segment_max, + unsorted_segment_min as unsorted_segment_min, + unsorted_segment_mean as unsorted_segment_mean, +) + diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 5e72756ef..5be7612c3 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -2,7 +2,6 @@ from brainpy._src.math.others import ( shared_args_over_time as shared_args_over_time, -) -from brainpy._src.math.ndarray import ( - npfun_returns_bparray as npfun_returns_bparray + remove_diag as remove_diag, + clip_by_norm as clip_by_norm, ) From a52de708c0cd09d7719cd0660994645b67fd21c1 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 29 Jan 2023 10:56:15 +0800 Subject: [PATCH 2/3] [compatibility] more operators in pytorch and tensorflow --- brainpy/_src/math/_utils.py | 21 ++++++++++- brainpy/_src/math/compat_numpy.py | 49 +++++++++++++++++++++++--- brainpy/_src/math/compat_pytorch.py | 32 +++++++++++++++++ brainpy/_src/math/compat_tensorflow.py | 36 ++++++++++--------- brainpy/_src/tools/others.py | 17 ++++++++- brainpy/math/compat_pytorch.py | 13 +++++++ brainpy/math/compat_tensorflow.py | 2 ++ 7 files changed, 148 insertions(+), 22 deletions(-) diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py 
index 6c75126af..b5a856d06 100644
--- a/brainpy/_src/math/_utils.py
+++ b/brainpy/_src/math/_utils.py
@@ -39,10 +39,29 @@ def _compatible_with_brainpy_array(fun: Callable):
   @functools.wraps(fun)
   def new_fun(*args, **kwargs):
     args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf)
+    out = None
     if len(kwargs):
+      # compatible with the PyTorch syntax, which uses ``dim`` and ``keepdim``
+      if 'dim' in kwargs:
+        kwargs['axis'] = kwargs.pop('dim')
+      if 'keepdim' in kwargs:
+        kwargs['keepdims'] = kwargs.pop('keepdim')
+      # compatible with the TensorFlow v1 syntax, which uses ``keep_dims``
+      if 'keep_dims' in kwargs:
+        kwargs['keepdims'] = kwargs.pop('keep_dims')
+      # compatible with NumPy/PyTorch syntax; ``out`` must be popped so that
+      # it is not forwarded to the wrapped jax.numpy function
+      if 'out' in kwargs:
+        out = kwargs.pop('out')
+        if not isinstance(out, Array):
+          raise TypeError(f'"out" must be an instance of brainpy Array, but got {type(out)}')
+      # format
       kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
     r = fun(*args, **kwargs)
-    return tree_map(_return, r)
+    if out is None:
+      return tree_map(_return, r)
+    else:
+      out.value = r
+      return out  # NumPy convention: the ``out`` array is also returned
 
   new_fun.__doc__ = getattr(fun, "__doc__", None)
diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py
index 6e4ca7fe9..eeb408798 100644
--- a/brainpy/_src/math/compat_numpy.py
+++ b/brainpy/_src/math/compat_numpy.py
@@ -1,14 +1,24 @@
 # -*- coding: utf-8 -*-
 
+from typing import (Union, Any, Protocol)
+
 import jax.numpy as jnp
 import numpy as np
-from jax.tree_util import tree_map
 from jax.tree_util import tree_flatten, tree_unflatten
+from jax.tree_util import tree_map
 
 from ._utils import _compatible_with_brainpy_array, _as_jax_array_
 from .arrayinterporate import *
 from .ndarray import Array
 
+
+class SupportsDType(Protocol):
+  @property
+  def dtype(self) -> np.dtype: ...
+
+
+DTypeLike = Union[Any, str, np.dtype, SupportsDType]
+
 __all__ = [
   'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
   'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
@@ -99,10 +109,39 @@
 ]
 
-
 _min = min
 _max = max
 
+# def concatenate(arrays: Union[np.ndarray, Array, Sequence[Array]],
+#                 axis: Optional[int] = None,
+#                 dim: Optional[int] = None,
+#                 dtype: Optional[DTypeLike] = None) -> Array:
+#   """Join a sequence of arrays along an existing axis.
+#
+#
+#   Parameters
+#   ----------
+#   a1, a2, ... : sequence of array_like
+#     The arrays must have the same shape, except in the dimension
+#     corresponding to `axis` (the first, by default).
+#   axis : int, optional
+#     The axis along which the arrays will be joined. If axis is None,
+#     arrays are flattened before use. Default is 0.
+#   dtype : str or dtype
+#     If provided, the destination array will have this dtype. Cannot be
+#     provided together with `out`.
+#
+#   Returns
+#   -------
+#   res : ndarray
+#     The concatenated array.
+#   """
+#   axis = one_of(0, axis, dim, ['axis', 'dim'])
+#   r = jnp.concatenate(tree_map(_as_jax_array_, arrays, is_leaf=_is_leaf),
+#                       axis=axis,
+#                       dtype=dtype)
+#   return _return(r)
 
 def fill_diagonal(a, val, inplace=True):
@@ -112,13 +151,14 @@ def fill_diagonal(a, val, inplace=True):
                      'it requires a brainpy Array. If you want to disable '
                      'in-place updating, use ``fill_diagonal(inplace=False)``.')
   val = val.value if isinstance(val, Array) else val
-  i, j = jnp.diag_indices(min(a.shape[-2:]))
+  i, j = jnp.diag_indices(_min(a.shape[-2:]))
   r = as_jax(a).at[..., i, j].set(val)
   if inplace:
     a.value = r
   else:
     return r
 
+
 def zeros(shape, dtype=None):
@@ -191,6 +231,7 @@ def logspace(*args, **kwargs):
   kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
   return Array(jnp.logspace(*args, **kwargs))
 
+
 def asanyarray(a, dtype=None, order=None):
   return asarray(a, dtype=dtype, order=order)
 
@@ -612,7 +653,7 @@ def common_type(*arrays):
     p = array_precision.get(t, None)
     if p is None:
       raise TypeError("can't get common type for non-numeric array")
-    precision = max(precision, p)
+    precision = _max(precision, p)
   if is_complex:
     return array_type[1][precision]
   else:
diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py
index 8fc5f1125..83f87b312 100644
--- a/brainpy/_src/math/compat_pytorch.py
+++ b/brainpy/_src/math/compat_pytorch.py
@@ -5,12 +5,25 @@
 import numpy as np
 
 from .ndarray import Array, _as_jax_array_
+from .compat_numpy import (
+  concatenate,
+)
 
 __all__ = [
+  'Tensor',
   'flatten',
+  'cat',
+
+  # data types
+  'bfloat16', 'half', 'float', 'double', 'cfloat', 'cdouble', 'short', 'int', 'long', 'bool'
 ]
 
+
+Tensor = Array
+cat = concatenate
+
+
 def flatten(input: Union[jax.Array, Array],
@@ -56,3 +69,22 @@ def flatten(input: Union[jax.Array, Array],
   new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:]
   return jnp.reshape(input, new_shape)
 
+# data types
+bfloat16 = jnp.bfloat16
+half = jnp.float16
+float = jnp.float32
+double = jnp.float64
+cfloat = jnp.complex64
+cdouble = jnp.complex128
+short = jnp.int16
+int = jnp.int32
+long = jnp.int64
+bool = jnp.bool_
+# missing types #
+# chalf = np.complex32
+# quint8 = jnp.quint8
+# qint8 = jnp.qint8
+# qint32 = jnp.qint32
+# quint4x2 = jnp.quint4x2
+
+
diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py
index c46fabe21..1c1da6b0d 100644
--- a/brainpy/_src/math/compat_tensorflow.py
+++ b/brainpy/_src/math/compat_tensorflow.py
@@ -2,18 +2,31 @@
 import jax.ops
 
 from .ndarray import _return, _as_jax_array_
-from .compat_numpy import prod, min, sum, all, any, mean, std, var
+from .compat_numpy import (
+  prod, min, sum, all, any, mean, std, var, concatenate, clip
+)
 
 __all__ = [
-  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all',
-  'reduce_any', 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance',
-  'reduce_euclidean_norm',
-  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum',
-  'unsorted_segment_prod', 'unsorted_segment_max', 'unsorted_segment_min',
-  'unsorted_segment_mean',
+  'concat',
+  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all', 'reduce_any',
+  'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm',
+  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod',
+  'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean',
+  'clip_by_value',
 ]
 
 
+reduce_prod = prod
+reduce_sum = sum
+reduce_all = all
+reduce_any = any
+reduce_min = min
+reduce_mean = mean
+reduce_std = std
+reduce_variance = var
+concat = concatenate
+clip_by_value = clip
+
+
 def reduce_logsumexp(input_tensor,
                      axis=None, keep_dims=False):
   """Computes log(sum(exp(elements across dimensions of a tensor))).
@@ -95,15 +108,6 @@ def reduce_max(input_tensor, axis=None, keep_dims=False):
   return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keepdims=keep_dims))
 
 
-reduce_prod = prod
-reduce_sum = sum
-reduce_all = all
-reduce_any = any
-reduce_min = min
-reduce_mean = mean
-reduce_std = std
-reduce_variance = var
-
 
 def segment_mean(data, segment_ids):
diff --git a/brainpy/_src/tools/others.py b/brainpy/_src/tools/others.py
index 7a6fc41d0..d945d890a 100644
--- a/brainpy/_src/tools/others.py
+++ b/brainpy/_src/tools/others.py
@@ -3,7 +3,7 @@
 import collections.abc
 import _thread as thread
 import threading
-from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar
+from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar, Any
 
 import numpy as np
 from jax import lax
@@ -11,6 +11,7 @@
 from tqdm.auto import tqdm
 
 __all__ = [
+  'one_of',
   'replicate',
   'not_customized',
   'to_size',
@@ -20,6 +21,20 @@
 ]
 
 
+def one_of(default: Any, *choices, names: Optional[Sequence[str]] = None):
+  names = [f'arg{i}' for i in range(len(choices))] if names is None else names
+  res = default
+  has_chosen = False
+  for c in choices:
+    if c is not None:
+      if has_chosen:
+        raise ValueError(f'Provide only one of {names}, but got {list(zip(names, choices))}')
+      else:
+        has_chosen = True
+        res = c
+  return res
+
+
 T = TypeVar('T')
diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py
index 157a9f967..6bc1b61cb 100644
--- a/brainpy/math/compat_pytorch.py
+++ b/brainpy/math/compat_pytorch.py
@@ -1,4 +1,17 @@
 
 from brainpy._src.math.compat_pytorch import (
+  Tensor as Tensor,
   flatten as flatten,
+  cat as cat,
+
+  bfloat16 as bfloat16,
+  half as half,
+  float as float,
+  double as double,
+  cfloat as cfloat,
+  cdouble as cdouble,
+  short as short,
+  int as int,
+  long as long,
+  bool as bool,
 )
diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py
index b027ca40d..58433a364 100644
--- a/brainpy/math/compat_tensorflow.py
+++ b/brainpy/math/compat_tensorflow.py
@@ -1,5 +1,6 @@
 
 from brainpy._src.math.compat_tensorflow import (
+  concat as concat,
   reduce_sum as reduce_sum,
   reduce_max as reduce_max,
   reduce_min as reduce_min,
   reduce_mean as reduce_mean,
@@ -18,5 +19,6 @@
   unsorted_segment_max as unsorted_segment_max,
   unsorted_segment_min as unsorted_segment_min,
   unsorted_segment_mean as unsorted_segment_mean,
+  clip_by_value as clip_by_value,
 )

From 692719530a6ba179a3a4cd1fbf65e805ec5b9857 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 29 Jan 2023 11:07:51 +0800
Subject: [PATCH 3/3] fix bug

---
 brainpy/_src/math/compat_numpy.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py
index eeb408798..c2c1c4b2f 100644
--- a/brainpy/_src/math/compat_numpy.py
+++ b/brainpy/_src/math/compat_numpy.py
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
 
-from typing import (Union, Any, Protocol)
-
 import jax.numpy as jnp
 import numpy as np
 from jax.tree_util import tree_flatten, tree_unflatten
@@ -12,13 +10,6 @@
 from .ndarray import Array
 
 
-class SupportsDType(Protocol):
-  @property
-  def dtype(self) -> np.dtype: ...
-
-
-DTypeLike = Union[Any, str, np.dtype, SupportsDType]
-
 __all__ = [
   'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
   'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
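
Usage sketch: a minimal example of the compatibility layer introduced by this series, assuming the `brainpy/math/__init__.py` re-exports above are in place. The NumPy, PyTorch, and TensorFlow spellings all route to the same jax.numpy implementations, so they can be mixed freely:

    import brainpy.math as bm

    x = bm.arange(12.).reshape(3, 4)

    # NumPy style (compat_numpy)
    s0 = bm.sum(x, axis=1, keepdims=True)

    # PyTorch style: ``dim``/``keepdim`` are rewritten to ``axis``/``keepdims``
    # by ``_compatible_with_brainpy_array`` (patch 2)
    s1 = bm.sum(x, dim=1, keepdim=True)
    f = bm.flatten(x, start_dim=0, end_dim=1)   # shape (12,)
    c = bm.cat([x, x], axis=0)                  # ``cat`` is an alias of ``concatenate``

    # TensorFlow style (compat_tensorflow)
    m = bm.reduce_max(x, axis=0, keep_dims=False)
    s = bm.unsorted_segment_sum(bm.arange(4.),
                                bm.asarray([0, 0, 1, 1]),
                                num_segments=2)  # -> [1., 5.]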