From 3534796d01c825b73b3d047fae70e101ec4e22c2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 15 Feb 2023 13:55:14 +0800 Subject: [PATCH 1/6] fix bugs of torch compatibility in `ndarray` --- brainpy/_src/math/ndarray.py | 578 ++++++++++++++--------------------- 1 file changed, 229 insertions(+), 349 deletions(-) diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 17493d79a..084ba8025 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -79,6 +79,11 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj +def _check_out(out): + if not isinstance(out, Array): + raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}') + + class Array(object): """Multiple-dimensional array in BrainPy. """ @@ -512,9 +517,16 @@ def choose(self, choices, mode='raise'): """Use an index array to construct a new array from a set of choices.""" return _return(self.value.choose(choices=_as_jax_array_(choices), mode=mode)) - def clip(self, min=None, max=None): + def clip(self, min=None, max=None, out=None, ): """Return an array whose values are limited to [min, max]. One of max or min must be given.""" - return _return(self.value.clip(min=min, max=max)) + min = _as_jax_array_(min) + max = _as_jax_array_(max) + r = self.value.clip(min=min, max=max) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r def compress(self, condition, axis=None): """Return selected slices of this array along given axis.""" @@ -999,11 +1011,11 @@ def view(self, *args, dtype=None): else: return _return(self.value.view(dtype)) else: - if isinstance(args[0], int): # shape + if isinstance(args[0], int): # shape if dtype is not None: raise ValueError('Provide one of dtype or shape. Not both.') return _return(self.value.reshape(*args)) - else: # dtype + else: # dtype assert not isinstance(args[0], int) assert dtype is None return _return(self.value.view(args[0])) @@ -1115,25 +1127,6 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': """ return Array(jnp.expand_dims(self.value, axis)) - def expand(self, *shape: Union[int, Sequence[int]]) -> 'Array': - """ - Expand an array to a new shape. - - Parameters - ---------- - shape : tuple or int - The shape of the desired array. A single integer ``i`` is interpreted - as ``(i,)``. - - Returns - ------- - expanded : Array - A readonly view on the original array with the given shape. It is - typically not contiguous. Furthermore, more than one element of a - expanded array may refer to a single memory location. - """ - return Array(jnp.broadcast_to(self._value, shape)) - def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': """ Expand an array to a shape of another array. 
@@ -1153,381 +1146,266 @@ def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array':
       array = Array(array)
     return Array(jnp.broadcast_to(self.value, array.value.shape))
 
-  def squeeze(self,
-              axis: Optional[Union[int, Sequence]]=None) -> 'Array':
-    return Array(self.squeeze(axis))
-
   # def item(self, *args) -> Any:
   #   return self.value.item(*args)
 
   def pow(self, index: int):
-    return self._value ** index
-
-  def addr(self,
-           vec1: Union['Array', jax.Array, np.ndarray],
-           vec2: Union['Array', jax.Array, np.ndarray],
-           *,
-           beta: float = 1.0,
-           alpha: float = 1.0,
-           out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union[None, NoReturn]:
-    if not isinstance(beta, int) and not isinstance(beta, float):
-      raise Exception('Wrong beta param of addr')
-    if not isinstance(alpha, int) and not isinstance(alpha, float):
-      raise Exception('Wrong alpha param of addr')
-    if not isinstance(vec1, Array):
-      vec1 = Array(vec1)
-    if not isinstance(vec2, Array):
-      vec2 = Array(vec2)
-    if not isinstance(out, Array):
-      out = Array(out)
-    return _return(brainpy.math.outer(vec1, vec2, out=out))
-
-
-  def addr_(self,
-            vec1: Union['Array', jax.Array, np.ndarray],
-            vec2: Union['Array', jax.Array, np.ndarray],
-            *,
-            beta: float = 1.0,
-            alpha: float = 1.0) -> Union['Array', NoReturn]:
-    if not isinstance(beta, (int,float)):
-      raise Exception('Wrong beta param of addr')
-    if not isinstance(alpha, (int,float)):
-      raise Exception('Wrong alpha param of addr')
-    if not isinstance(vec1, Array):
-      vec1 = Array(vec1)
-    if not isinstance(vec2, Array):
-      vec2 = Array(vec2)
-    # self.value *= beta
-    # self.value += alpha * jnp.outer(vec1, vec2)
-    return brainpy.math.outer(vec1, vec2, out=self)
-
-  def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> Union[NoReturn, None]:
-    # if other is None:
-    #   raise Exception('Array can not make outer product with None')
-    if not isinstance(other, Array):
-      other = Array(other)
-    return _return(jnp.outer(self.value, other.value))
+    return _return(self._value ** index)
 
-  def sum(self) -> 'Array':
-    return _return(self.value.sum())
+  def addr(
+      self,
+      vec1: Union['Array', jax.Array, np.ndarray],
+      vec2: Union['Array', jax.Array, np.ndarray],
+      *,
+      beta: float = 1.0,
+      alpha: float = 1.0,
+      out: Optional[Union['Array', jax.Array, np.ndarray]] = None
+  ) -> Optional['Array']:
+    r"""Performs the outer product of vectors ``vec1`` and ``vec2`` and adds it to the matrix ``input`` (here, ``self``).
+
+    The optional values ``beta`` and ``alpha`` are scaling factors on the outer
+    product of ``vec1`` and ``vec2`` and on the added matrix ``input``, respectively.
 
-  def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array':
-    # return Array(self.value.__abs__())
-    abs_value = None
-    if out is not None:
-      if not isinstance(out, Array):
-        out = Array(out)
-      abs_value = brainpy.math.abs(self.value, out=out)
+    .. math::
+
+       out = \beta \mathrm{input} + \alpha (\text{vec1} \otimes \text{vec2})
+
+    Args:
+      vec1: the first vector of the outer product
+      vec2: the second vector of the outer product
+      beta: multiplier for ``input``
+      alpha: multiplier for the outer product of ``vec1`` and ``vec2``
+      out: the output tensor. 
+ + """ + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def addr_( + self, + vec1: Union['Array', jax.Array, np.ndarray], + vec2: Union['Array', jax.Array, np.ndarray], + *, + beta: float = 1.0, + alpha: float = 1.0 + ) -> None: + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + self.value = r + + def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> 'Array': + other = _as_jax_array_(other) + return _return(jnp.outer(self.value, other.value)) + + def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.abs(self.value) + if out is None: + return _return(r) else: - abs_value = brainpy.math.abs(self.value) - # if isinstance(out, (Array, jax.Array, np.ndarray)): - # out.value = abs_value - return _return(abs_value) + _check_out(out) + out.value = r - def abs_(self) -> 'Array': + def abs_(self) -> None: """ in-place version of Array.abs() """ - return brainpy.math.abs(self, out=self) + self.value = jnp.abs(self.value) - def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: """ alias of Array.abs """ - if not isinstance(out, Array): - out = Array(out) return self.abs(out=out) - def absolute_(self) -> 'Array': + def absolute_(self) -> None: """ alias of Array.abs_() """ return self.abs_() - def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.sin(self.value, out=out) + def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.sin(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.sin(self.value) - return Array(value) - - def sin_(self) -> 'Array': - return Array(brainpy.math.sin(self.value, out=self)) - - def cos_(self) -> 'Array': - return Array(brainpy.math.cos(self.value, out=self)) - - def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.cos(self.value, out=out.value) + _check_out(out) + out.value = r + + def sin_(self) -> None: + self.value = jnp.sin(self.value) + + def cos_(self) -> None: + self.value = jnp.cos(self.value) + + def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.cos(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def tan_(self) -> None: + self.value = jnp.tan(self.value) + + def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.tan(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.cos(self.value) - return Array(value) - - def 
tan_(self) -> 'Array':
-    return Array(brainpy.math.tan(self.value, out=self.value))
-
-  def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]:
-    # return Array(self.value.__abs__())
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-    value = None
-    if out is not None:
-      if not isinstance(out, Array):
-        out = Array(out)
-      value = brainpy.math.tan(self.value, out=out.value)
+      _check_out(out)
+      out.value = r
+
+  def sinh_(self) -> None:
+    self.value = jnp.sinh(self.value)
+
+  def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']:
+    r = jnp.sinh(self.value)
+    if out is None:
+      return _return(r)
+    else:
+      _check_out(out)
+      out.value = r
+
+  def cosh_(self) -> None:
+    self.value = jnp.cosh(self.value)
+
+  def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']:
+    r = jnp.cosh(self.value)
+    if out is None:
+      return _return(r)
     else:
-      value = brainpy.math.tan(self.value)
-    return Array(value)
-
-  def sinh_(self) -> 'Array':
-    return Array(brainpy.math.sinh(self.value, out=self.value))
-
-  def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]:
-    # return Array(self.value.__abs__())
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-    value = None
-    if out is not None:
-      if not isinstance(out, Array):
-        out = Array(out)
-      value = brainpy.math.sinh(self.value, out=out.value)
+      _check_out(out)
+      out.value = r
+
+  def tanh_(self) -> None:
+    self.value = jnp.tanh(self.value)
+
+  def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']:
+    r = jnp.tanh(self.value)
+    if out is None:
+      return _return(r)
     else:
-      value = brainpy.math.sinh(self.value)
-    return Array(value)
-
-  def cosh_(self) -> 'Array':
-    return Array(brainpy.math.cosh(self.value, out=self.value))
-
-  def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]:
-    # return Array(self.value.__abs__())
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-    if not isinstance(out, Array):
-      out = Array(out)
-    return Array(brainpy.math.cosh(self.value, out=out.value))
-
-  def tanh_(self) -> 'Array':
-    return Array(brainpy.math.tanh(self.value, out=self.value))
-
-  def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]:
-    # return Array(self.value.__abs__())
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-    value = None
-    if out is not None:
-      if not isinstance(out, Array):
-        out = Array(out)
-      value = brainpy.math.tanh(self.value, out=out.value)
+      _check_out(out)
+      out.value = r
+
+  def arcsin_(self) -> None:
+    self.value = jnp.arcsin(self.value)
+
+  def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']:
+    r = jnp.arcsin(self.value)
+    if out is None:
+      return _return(r)
    else:
-      value = brainpy.math.tanh(self.value)
-    return Array(value)
-
-  def arcsin_(self) -> 'Array':
-    return Array(brainpy.math.arcsin(self.value, out=self.value))
-
-  def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]:
-    # return 
Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - if not isinstance(out, Array): - out = Array(out) - return Array(brainpy.math.arcsin(self.value, out=out.value)) - - def arccos_(self) -> 'Array': - return Array(brainpy.math.arccos(self.value, out=self.value)) - - def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.arccos(self.value, out=out.value) + _check_out(out) + out.value = r + + def arccos_(self) -> None: + self.value = jnp.arccos(self.value) + + def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arccos(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.arccos(self.value) - return Array(value) - - def arctan_(self) -> 'Array': - return Array(brainpy.math.arctan(self.value, out=self.value)) - - def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.arctan(self.value, out=out.value) + _check_out(out) + out.value = r + + def arctan_(self) -> None: + self.value = jnp.arctan(self.value) + + def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arctan(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.arctan(self.value) - return Array(value) - - # def all(self, - # axis: Optional[int] = None, - # keepdim: bool = False, - # *, - # out: Optional[Union['Array', jax.Array, np.ndarray]] = None): - # """ - # test if all element cast to true - # """ - # # if out is not None: - # # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # # raise Exception('Unexcepted param out') - # value = value = brainpy.math.all(self.value, axis, keepdim) - # if out is not None: - # if not isinstance(out, Array): - # warnings.showwarning("out is not a brainpy Array") - # out = Array(out) - # out.update(value) - # return Array(value) - - # def any(self, - # dim: int, - # keepdim: bool, - # *, - # out: Optional[Union['Array', jax.Array, np.ndarray]] = None): - # """ - # test if any element cast to true - # """ - # value = value = brainpy.math.any(self.value) - # if out is not None: - # if not isinstance(out, Array): - # warnings.showwarning("out is not a brainpy Array") - # out = Array(out) - # out.update(value) - # return Array(value) - - def clamp(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - *, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + _check_out(out) + out.value = r + + def clamp( + self, + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + *, + out: Optional[Union['Array', jax.Array, np.ndarray]] = None + ) -> Optional['Array']: """ return the value 
between min_value and max_value,
    if min_value is None, then no lower bound,
    if max_value is None, then no upper bound.
    """
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-
-    # value = None
-    # if out is not None:
-    #   if not isinstance(out, Array):
-    #     out = Array(out)
-    #   value = brainpy.math.clip(self.value, min_value, max_value, out=out)
-    # else:
-    #   value = brainpy.math.clip(self.value, min_value, max_value)
-    # return Array(value)
-
-    return _return(self.value.clip(min_value, max_value))
+    min_value = _as_jax_array_(min_value)
+    max_value = _as_jax_array_(max_value)
+    r = jnp.clip(self.value, min_value, max_value)
+    if out is None:
+      return _return(r)
+    else:
+      _check_out(out)
+      out.value = r
 
   def clamp_(self,
-            min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
-            max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array':
+             min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
+             max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None:
     """
     return the value between min_value and max_value,
     if min_value is None, then no lower bound,
     if max_value is None, then no upper bound.
     """
-    return brainpy.math.clip(self.value, min_value, max_value, out=self)
+    self.clamp(min_value, max_value, out=self)
 
   def clip_(self,
-           min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
-           max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array':
-    """
-    alias for clamp_
-    """
-    return Array(brainpy.math.clip(self.value, min_value, max_value, out=self))
-
-  def clip(self,
            min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
-           max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
-           *,
-           out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array':
+           max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None:
     """
-    alias for clamp
+    alias for clamp_
     """
-    # if out is not None:
-    #   if not isinstance(out, (Array, jax.Array, np.ndarray)):
-    #     raise Exception('Unexcepted param out')
-    value = None
-    if out is not None:
-      if not isinstance(out, Array):
-        out = Array(out)
-      value = brainpy.math.clip(self.value, out=out)
-    else:
-      value = brainpy.math.clip(self.value)
-    return Array(value)
+    return self.clip(min_value, max_value, out=self)
 
   def clone(self) -> 'Array':
-    return Array(brainpy.math.copy(self.value))
-
-  def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array':
-    value = None
-    if src is not None:
-      if not isinstance(src, Array):
-        src = Array(src)
-      value = brainpy.math.copyto(self.value, src)
-    else:
-      raise Exception("copy from None???")
-    return self
-
-  # def conj(self) -> 'Array':
-  #   return Array(brainpy.math.conj(self.value))
-
-  def cov_with(self,
-               y: Optional[Union['Array', jax.Array, np.ndarray]] = None,
-               rowvar: bool = True,
-               bias: bool = False,
-               ddof: Optional[int] = None,
-               fweights: Union['Array', jax.Array, np.ndarray] = None,
-               aweights: Union['Array', jax.Array, np.ndarray] = None) -> 'Array':
-    return Array(brainpy.math.cov(self.value, y, rowvar, bias, fweights, aweights))
-
-  def cov(self,
-          *,
-          correction: int = 1,
-          fweights: Union['Array', jax.Array, np.ndarray] = None,
-          aweights: Union['Array', jax.Array, np.ndarray] = None) -> Union['Array', NoReturn]:
-    try:
-      x = [e[0] for e in self.value]
-      y = [e[1] for e in self.value]
-      return Array(brainpy.math.cov(x, y, ddof=correction, fweights=fweights, aweights=aweights))
-    except 
Exception as e:
-      raise Exception('Wrong format, need to be [[x1,y1],[x2,y2],[x3,y3]]')
+    return Array(self.value.copy())
 
+  def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> None:
+    self.value = jnp.copy(_as_jax_array_(src))
 
-    # ------------------
-    # Torch support
-    # ------------------
+  def cov_with(
+      self,
+      y: Optional[Union['Array', jax.Array, np.ndarray]] = None,
+      rowvar: bool = True,
+      bias: bool = False,
+      ddof: Optional[int] = None,
+      fweights: Union['Array', jax.Array, np.ndarray] = None,
+      aweights: Union['Array', jax.Array, np.ndarray] = None
+  ) -> 'Array':
+    y = _as_jax_array_(y)
+    fweights = _as_jax_array_(fweights)
+    aweights = _as_jax_array_(aweights)
+    r = jnp.cov(self.value, y, rowvar, bias, ddof, fweights, aweights)
+    return Array(r)
 
   def expand(self, *sizes) -> 'Array':
+    """
+    Expand an array to a new shape.
+
+    Parameters
+    ----------
+    sizes : tuple of ints, or int
+      The shape of the desired array. A single integer ``i`` is interpreted
+      as ``(i,)``.
+
+    Returns
+    -------
+    expanded : Array
+      A readonly view on the original array with the given shape. It is
+      typically not contiguous. Furthermore, more than one element of an
+      expanded array may refer to a single memory location.
+    """
     l_ori = len(self.shape)
     l_tar = len(sizes)
     base = l_tar - l_ori
@@ -1537,12 +1415,14 @@ def expand(self, *sizes) -> 'Array':
                        f'dimensions in the tensor ({len(self.shape)})')
     for i, v in enumerate(sizes[:base]):
       if v < 0:
-        raise ValueError(f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}')
+        raise ValueError(
+          f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}')
     for i, v in enumerate(self.shape):
       sizes_list[base + i] = v if sizes_list[base + i] == -1 else sizes_list[base + i]
       if v != 1 and sizes_list[base + i] != v:
-        raise ValueError(f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton '
-                         f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}')
+        raise ValueError(
+          f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton '
+          f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}')
     return Array(jnp.broadcast_to(self.value, sizes_list))
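Below is a minimal usage sketch (an illustration, not part of the patch itself) of the PyTorch-style ``out=`` and in-place (trailing underscore) semantics that PATCH 1 gives the ``Array`` methods; it assumes only that the public ``brainpy.math`` namespace exposes ``Array`` and the NumPy-like creation functions as usual:

    import brainpy.math as bm

    x = bm.Array(bm.arange(4.))
    buf = bm.Array(bm.zeros(4))

    y = x.sin()                      # functional form: returns the result
    x.sin(out=buf)                   # writes the result into ``buf.value`` and returns None
    x.sin_()                         # in-place form: overwrites ``x.value``
    x.clip(min=0., max=2., out=buf)  # ``out`` must be a brainpy Array, otherwise _check_out raises TypeError

    # ``addr`` computes alpha * outer(vec1, vec2) + beta * self
    m = bm.Array(bm.zeros((3, 3)))
    r = m.addr(bm.ones(3), bm.ones(3), beta=1.0, alpha=0.5)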
From ca7f80a9fddf6280c280af35307fa1aa33ea5142 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 15 Feb 2023 13:56:58 +0800
Subject: [PATCH 2/6] ``share`` as the global context for sharing data across
 all modules/nodes

---
 brainpy/_src/experimental/delay.py | 293 ++---------------------------
 brainpy/_src/math/context.py       | 117 ++++++++++++
 brainpy/_src/math/delayvars.py     | 227 +++++++++++-----------
 brainpy/_src/math/environment.py   |  78 +-------
 brainpy/math/context.py            |   4 +
 brainpy/math/delayvars.py          |   1 +
 brainpy/math/environment.py        |   1 -
 7 files changed, 265 insertions(+), 456 deletions(-)
 create mode 100644 brainpy/_src/math/context.py
 create mode 100644 brainpy/math/context.py

diff --git a/brainpy/_src/experimental/delay.py b/brainpy/_src/experimental/delay.py
index 8e8ad3cbf..d4222e30b 100644
--- a/brainpy/_src/experimental/delay.py
+++ b/brainpy/_src/experimental/delay.py
@@ -1,300 +1,47 @@
 # -*- coding: utf-8 -*-
 
-from typing import Union, Callable, Optional, Tuple, Sequence, Dict
+from typing import Union, Callable, Optional, Dict
 
 import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.lax import stop_gradient
 
-from brainpy import check, math as bm
-from brainpy._src.math.object_transform.base import Collector
+from brainpy import math as bm
 from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs
+from brainpy._src.math.delayvars import DelayVariable, ROTATE_UPDATE, CONCAT_UPDATE
 
-ROTATE_UPDATE = 'rotation'
-CONCAT_UPDATE = 'concat'
 
+class Delay(DynamicalSystem, DelayVariable):
+  """Delay variable for dynamical systems, with a fixed delay length.
 
-class Delay(DynamicalSystem):
-  """Delay variable which has a fixed delay length.
-
-  The data in this delay variable is arranged as::
-
-     delay = 0             [ data
-     delay = 1               data
-     delay = 2               data
-     ...                     ....
-     ...                     ....
-     delay = length-1        data
-     delay = length          data ]
-
-  Parameters
-  ----------
-  target: Variable
-    The initial delay data.
-  length: int
-    The delay data length.
-  initial_delay_data: Any
-    The delay data. It can be a Python number, like float, int, boolean values.
-    It can also be arrays. Or a callable function or instance of ``Connector``.
-    Note that ``initial_delay_data`` should be arranged as the following way::
-
-       delay = 1             [ data
-       delay = 2               data
-       ...                     ....
-       ...                     ....
-       delay = length-1        data
-       delay = length          data ]
-  method: str
-    The method used for updating delay.
-
+  For the detailed docstring, please see :py:class:`~.DelayVariable`.
   """
 
-  data: Optional[bm.Variable]
-  idx: Optional[bm.Variable]
-  length: int
-
   def __init__(
       self,
       target: bm.Variable,
       length: int = 0,
-      initial_delay_data: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
+      before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
       entries: Optional[Dict] = None,
+      method: str = None,
       mode: bm.Mode = None,
       name: str = None,
-      method: str = None,
   ):
-    super().__init__(mode=mode, name=name)
-
-    # delay updating method
     if method is None:
       if self.mode.is_a(bm.NonBatchingMode):
         method = ROTATE_UPDATE
-      else:
+      elif self.mode.is_parent_of(bm.TrainingMode):
         method = CONCAT_UPDATE
-      assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
-      self.method = method
-
-      # target
-      self.target = target
-      if not isinstance(target, bm.Variable):
-        raise ValueError(f'Must be an instance of brainpy.math.Variable. 
But we got {type(target)}') - - # delay length - self.length = is_integer(length, allow_none=False, min_bound=0) - - # delay data - if initial_delay_data is not None: - assert isinstance(initial_delay_data, (int, float, bool, bm.Array, jax.Array, Callable)) - self._initial_delay_data = initial_delay_data - if length > 0: - self._init_data(length) - else: - self.data = None - - # time variables - if self.method == ROTATE_UPDATE: - self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) - - # other info - self._access_to_step = dict() - for entry, value in entries.items(): - self.register_entry(entry, value) - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]] = None, - delay_step: Optional[Union[int, bm.Array, Callable]] = None, - ) -> 'Delay': - """Register an entry to access the data. - - Args: - entry (str): The entry to access the delay data. - delay_step: The delay step of the entry (must be an integer, denoting the delay step). - delay_time: The delay time of the entry (can be a float). - - Returns: - Return the self. - """ - if entry in self._access_to_step: - raise KeyError(f'Entry {entry} has been registered.') - - if delay_time is not None: - if delay_step is not None: - raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') - if callable(delay_time): - delay_time = bm.as_jax(delay_time(self.delay_target_shape)) - delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) - elif isinstance(delay_time, float): - delay_step = int(delay_time / bm.get_dt()) - else: - delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) - - # delay steps - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): - if delay_step.size == 1 and delay_step.ndim == 0: - delay_type = 'homo' else: - delay_type = 'heter' - delay_step = bm.Array(delay_step) - elif callable(delay_step): - delay_step = delay_step(self.delay_target_shape) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [jnp.int32, jnp.int64]: - raise ValueError('Only support delay steps of int32, int64. If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if self.delay_target_shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') - if delay_type == 'heter': - max_delay_step = int(max(delay_step)) - elif delay_type == 'homo': - max_delay_step = delay_step - else: - max_delay_step = None - - # delay variable - if max_delay_step is not None: - if self.length < max_delay_step: - self._init_data(max_delay_step) - self.length = max_delay_step - self._access_to_step[entry] = delay_step - return self - - def at_entry(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. - - Args: - entry (str): The entry to access the data. - *indices: - - Returns: - The data. 
- """ - assert isinstance(entry, str) - if entry not in self._access_to_step: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._access_to_step[entry] - if delay_step is None: - return self.target.value - else: - if self.data is None: - return self.target.value - else: - if isinstance(delay_step, slice): - return self.retrieve(delay_step, *indices) - elif np.ndim(delay_step) == 0: - return self.retrieve(delay_step, *indices) - else: - if len(indices) == 0 and len(delay_step) == self.target.shape[0]: - indices = (jnp.arange(delay_step.size),) - return self.retrieve(delay_step, *indices) - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.target.shape - - def __repr__(self): - name = self.__class__.__name__ - return (f'{name}(num_delay_step={self.length}, ' - f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.method})') - - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.length}. ' - f'But we got {delay_len}') - - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. - - Parameters - ---------- - delay_step: int, ArrayType - The delay length used to retrieve the data. - """ - assert delay_step is not None - if check.is_checking(): - jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) - - if self.method == ROTATE_UPDATE: - delay_idx = (self.idx.value + delay_step) % (self.length + 1) - delay_idx = stop_gradient(delay_idx) - - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - - # the delay data - return self.data[indices] + method = ROTATE_UPDATE + DynamicalSystem.__init__(self, mode=mode) + DelayVariable.__init__(self, + target=target, + length=length, + before_t0=before_t0, + entries=entries, + method=method, + name=name) @not_pass_shargs - def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: - """Update delay variable with the new data. - """ - if self.data is not None: - # get the latest target value - if latest_value is None: - latest_value = self.target.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) - self.data[self.idx.value] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.length >= 2: - self.data.value = bm.vstack([latest_value, self.data[1:]]) - else: - self.data[0] = latest_value - - def reset_state(self, batch_size: int = None): - """Reset the delay data. - """ - # initialize delay data - if self.data is not None: - self._init_data(self.length, batch_size) - - # time variables - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) - - def _init_data(self, length, batch_size: int = None): - if batch_size is not None: - if self.target.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.target.batch_size} != {batch_size}). 
'
                          'Please reset the target variable first, because delay data '
                          'depends on the target variable. ')

     if self.target.batch_axis is None:
       batch_axis = None
     else:
       batch_axis = self.target.batch_axis + 1
     self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype),
                             batch_axis=batch_axis)
     # update delay data
     self.data[0] = self.target.value
     if isinstance(self._initial_delay_data, (bm.Array, jax.Array, float, int, bool)):
       self.data[1:] = self._initial_delay_data
     elif callable(self._initial_delay_data):
       self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype)

diff --git a/brainpy/_src/math/context.py b/brainpy/_src/math/context.py
new file mode 100644
index 000000000..7283569cc
--- /dev/null
+++ b/brainpy/_src/math/context.py
@@ -0,0 +1,117 @@
+"""
+Context for brainpy computation.
+
+This context defines all shared data used in all modules in a computation.
+"""
+
+from typing import Dict, Any
+
+from brainpy._src.tools.dicts import DotDict
+from .delayvars import DelayVariable
+from .object_transform.base import BrainPyObject
+from .environment import get_dt as _get_dt_
+
+__all__ = [
+  'share',
+]
+
+
+class DelayEntry:
+  def __init__(self, target: str, time=None, step=None):
+    if time is None and step is None:
+      raise ValueError('Please provide time or step.')
+    self.target = target
+    self.time = time
+    self.step = step
+
+
+class _ShareContext(BrainPyObject):
+  def __init__(self):
+    super().__init__()
+
+    # Shared data across all nodes at current time step.
+    # -------------
+
+    self._arguments = DotDict()
+    self._delays: Dict[str, DelayVariable] = DotDict()
+    self._delay_entries: Dict[str, str] = DotDict()
+    self._identifiers = set()
+
+  @property
+  def dt(self):
+    if 'dt' in self._arguments:
+      return self._arguments['dt']
+    else:
+      return _get_dt_()
+
+  def load(self, key):
+    """Get the shared data by the ``key``.
+
+    Args:
+      key (str): the key to indicate the data.
+    """
+    if key in self._arguments:
+      return self._arguments[key]
+    if key in self._delays:
+      return self._delays[key]
+    if key in self._delay_entries:
+      entry = key
+      delay = self._delay_entries[entry]
+      return self._delays[delay].at(entry)
+    raise KeyError(f'Cannot find shared data of {key}.')
+
+  def save(self, *args, **kwargs) -> None:
+    """Save shared data in the global context.
+
+    Both positional pairs, i.e. ``save(k1, v1, k2, v2)``, and keyword
+    arguments, i.e. ``save(k1=v1, k2=v2)``, are supported.
+    """
+    if len(args) % 2 != 0:
+      raise ValueError('Positional arguments should be given as "identifier, data" pairs.')
+    for identifier, data in tuple(zip(args[::2], args[1::2])) + tuple(kwargs.items()):
+      self._save_one(identifier, data)
+
+  def _save_one(self, identifier: str, data: Any) -> None:
+    assert isinstance(identifier, str)
+
+    if isinstance(data, DelayVariable):
+      if identifier in self._identifiers:
+        raise ValueError(f'{identifier} has been used. Please assign another name.')
+      self._delays[identifier] = data
+    elif isinstance(data, DelayEntry):
+      if identifier in self._identifiers:
+        raise ValueError(f'{identifier} has been used. Please assign another name.')
+      if isinstance(data.target, DelayVariable):
+        delay_key = f'delay{id(data)}'
+        self._save_one(delay_key, data.target)
+        delay = data.target
+      elif isinstance(data.target, str):
+        if data.target not in self._delays:
+          raise ValueError(f'Delay target {data.target} has not been registered.')
+        delay = self._delays[data.target]
+        delay_key = data.target
+      else:
+        raise ValueError(f'Unknown delay target. 
{type(data.target)}')
+      delay.register_entry(identifier, delay_time=data.time, delay_step=data.step)
+      self._delay_entries[identifier] = delay_key
+    else:
+      self._arguments[identifier] = data
+    self._identifiers.add(identifier)
+
+  def get_shargs(self) -> DotDict:
+    """Get all shared arguments in the global context."""
+    return self._arguments.copy()
+
+  def remove_shargs(self, *args) -> None:
+    """Clear all shared arguments in the global context."""
+    if len(args) > 0:
+      for a in args:
+        self._arguments.pop(a)
+    else:
+      self._arguments.clear()
+
+  def clear(self) -> None:
+    """Clear all shared data in this computation context."""
+    self._arguments.clear()
+    self._delays.clear()
+    self._delay_entries.clear()
+    self._identifiers.clear()
+
+  def update(self):
+    for delay in self._delays.values():
+      delay.update()
+
+  def reset_state(self, batch_size: int = None):
+    for delay in self._delays.values():
+      delay.reset_state(batch_size)
+
+
+share = _ShareContext()
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index fcb4af366..e17728ca2 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -1,20 +1,20 @@
 # -*- coding: utf-8 -*-
 
-from typing import Union, Callable, Optional
+from typing import Union, Callable, Optional, Dict
 
-import numpy as np
 import jax
 import jax.numpy as jnp
+import numpy as np
 from jax import vmap
 from jax.lax import cond, stop_gradient
 
 from brainpy import check, math as bm
 from brainpy.check import is_float, is_integer, jit_error_checking
 from brainpy.errors import UnsupportedError
-from .object_transform.base import BrainPyObject
-from .ndarray import ndarray, Variable, Array
 from .arrayinterporate import as_jax
 from .environment import get_dt, get_int
+from .ndarray import ndarray, Variable, Array
+from .object_transform.base import BrainPyObject
 
 __all__ = [
   'AbstractDelay',
@@ -31,8 +31,7 @@ def _as_jax_array(arr):
 
 
 class AbstractDelay(BrainPyObject):
-  def update(self, *args, **kwargs):
-    raise NotImplementedError
+  pass
 
 
 _FUNC_BEFORE = 'function'
@@ -474,85 +473,87 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]):
 class DelayVariable(AbstractDelay):
   """Delay variable which has a fixed delay length.
 
-  The data in this delay variable is arranged as::
-
-     delay = 0             [ data
-     delay = 1               data
-     delay = 2               data
-     ...                     ....
-     ...                     ....
-     delay = length-1        data
-     delay = length          data ]
-
-  Parameters
-  ----------
-  target: Variable
-    The initial delay data.
-  length: int
-    The delay data length.
-  initial_delay_data: Any
-    The delay data. It can be a Python number, like float, int, boolean values.
-    It can also be arrays. Or a callable function or instance of ``Connector``.
-    Note that ``initial_delay_data`` should be arranged as the following way::
+  The data in this delay variable is arranged as::
 
-       delay = 1             [ data
-       delay = 2               data
-       ...                     ....
-       ...                     ....
-       delay = length-1        data
-       delay = length          data ]
+     delay = 0             [ data
+     delay = 1               data
+     delay = 2               data
+     ...                     ....
+     ...                     ....
+     delay = length-1        data
+     delay = length          data ]
 
-  update_method: str
-    The method used for updating delay.
+  Parameters
+  ----------
+  target: Variable
+    The initial delay data.
+  length: int
+    The delay data length.
+  before_t0: Any
+    The delay data. It can be a Python number, like float, int, boolean values.
+    It can also be arrays. Or a callable function or instance of ``Connector``.
+    Note that ``before_t0`` should be arranged as the following way::
+
+       delay = 1             [ data
+       delay = 2               data
+       ...                     ....
+       ...                     .... 
+       delay = length-1        data
+       delay = length          data ]
+  method: str
+    The method used for updating delay.
 
-  See Also
-  --------
-  TimeDelay
-  """
+  """
 
-  data: Optional[Variable]
-  idx: Optional[Variable]
+  data: Optional[bm.Variable]
+  idx: Optional[bm.Variable]
   length: int
 
   def __init__(
       self,
-      target: Variable,
+      target: bm.Variable,
       length: int = 0,
-      initial_delay_data: Union[float, int, bool, Array, jax.Array, Callable] = None,
-      update_method: str = ROTATE_UPDATE
+      before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
+      entries: Optional[Dict] = None,
+      name: str = None,
+      method: str = ROTATE_UPDATE,
   ):
-    super().__init__()
-
-    assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE]
-    self.update_method = update_method
+    super().__init__(name=name)
+    assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
+    self.method = method
 
     # target
     self.target = target
-    if not isinstance(target, Variable):
+    if not isinstance(target, bm.Variable):
       raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}')
 
     # delay length
     self.length = is_integer(length, allow_none=False, min_bound=0)
 
     # delay data
-    if initial_delay_data is not None:
-      assert isinstance(initial_delay_data, (int, float, bool, Array, jax.Array, Callable))
-    self._initial_delay_data = initial_delay_data
-    self._init_data(length)
+    if before_t0 is not None:
+      assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable))
+    self._before_t0 = before_t0
+    if length > 0:
+      self._init_data(length)
+    else:
+      self.data = None
 
     # time variables
-    if self.update_method == ROTATE_UPDATE:
-      self.idx = Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))
+    if self.method == ROTATE_UPDATE:
+      self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))
 
     # other info
     self._access_to_step = dict()
+    if entries is not None:
+      for entry, value in entries.items():
+        self.register_entry(entry, value)
 
   def register_entry(
       self,
       entry: str,
-      delay_step: Optional[Union[int, Array, Callable]] = None,
-      delay_time: Optional[Union[float, Array, Callable]] = None,
-  ) -> 'DelayVariable':
+      delay_time: Optional[Union[float, bm.Array, Callable]] = None,
+      delay_step: Optional[Union[int, bm.Array, Callable]] = None,
+  ) -> 'Delay':
    """Register an entry to access the data.

    Args:
@@ -563,28 +564,31 @@ def register_entry(
    Returns:
      Return the self.
    """
+    if entry in self._access_to_step:
+      raise KeyError(f'Entry {entry} has been registered.')
+
    if delay_time is not None:
      if delay_step is not None:
        raise ValueError('Provide either "delay_time" or "delay_step". 
Both you have given both.') if callable(delay_time): - delay_time = _as_jax_array(delay_time(self.delay_target_shape)) - delay_step = jnp.asarray(delay_time / get_dt(), dtype=get_int()) + delay_time = bm.as_jax(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) elif isinstance(delay_time, float): - delay_step = int(delay_time / get_dt()) + delay_step = int(delay_time / bm.get_dt()) else: - delay_step = jnp.asarray(_as_jax_array(delay_time) / get_dt(), dtype=get_int()) + delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) # delay steps if delay_step is None: delay_type = 'none' elif isinstance(delay_step, int): delay_type = 'homo' - elif isinstance(delay_step, (Array, jax.Array, np.ndarray)): + elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): if delay_step.size == 1 and delay_step.ndim == 0: delay_type = 'homo' else: delay_type = 'heter' - delay_step = Array(delay_step) + delay_step = bm.Array(delay_step) elif callable(delay_step): delay_step = delay_step(self.delay_target_shape) delay_type = 'heter' @@ -613,7 +617,7 @@ def register_entry( self._access_to_step[entry] = delay_step return self - def at_entry(self, entry: str, *indices) -> Array: + def at(self, entry: str, *indices) -> bm.Array: """Get the data at the given entry. Args: @@ -625,20 +629,22 @@ def at_entry(self, entry: str, *indices) -> Array: """ assert isinstance(entry, str) if entry not in self._access_to_step: - raise KeyError(f'Does not find delay access "{entry}".') + raise KeyError(f'Does not find delay entry "{entry}".') delay_step = self._access_to_step[entry] if delay_step is None: return self.target.value else: - assert self.data is not None - if isinstance(delay_step, slice): - return self.retrieve(delay_step, *indices) - elif np.ndim(delay_step) == 0: - return self.retrieve(delay_step, *indices) + if self.data is None: + return self.target.value else: - if len(indices) == 0 and len(delay_step) == self.target.shape[0]: - indices = (jnp.arange(delay_step.size),) - return self.retrieve(delay_step, *indices) + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.target.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) @property def delay_target_shape(self): @@ -649,36 +655,34 @@ def __repr__(self): name = self.__class__.__name__ return (f'{name}(num_delay_step={self.length}, ' f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.update_method})') + f'update_method={self.method})') def _check_delay(self, delay_len): raise ValueError(f'The request delay length should be less than the ' f'maximum delay {self.length}. ' f'But we got {delay_len}') - def __call__(self, delay_len, *indices): - return self.retrieve(delay_len, *indices) - def retrieve(self, delay_step, *indices): - """Retrieve the delay data acoording to the delay length. + """Retrieve the delay data according to the delay length. Parameters ---------- delay_step: int, ArrayType The delay length used to retrieve the data. 
""" + assert delay_step is not None if check.is_checking(): jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) - if self.update_method == ROTATE_UPDATE: - delay_idx = (self.idx.value + delay_step) % self.length + if self.method == ROTATE_UPDATE: + delay_idx = (self.idx.value + delay_step) % (self.length + 1) delay_idx = stop_gradient(delay_idx) - elif self.update_method == CONCAT_UPDATE: + elif self.method == CONCAT_UPDATE: delay_idx = delay_step else: - raise ValueError(f'Unknown updating method "{self.update_method}"') + raise ValueError(f'Unknown updating method "{self.method}"') # the delay index if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): @@ -688,44 +692,57 @@ def retrieve(self, delay_step, *indices): # the delay data return self.data[indices] - def update(self): + def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: """Update delay variable with the new data. """ - # update the delay data at the rotation index - if self.update_method == ROTATE_UPDATE: - self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.length)) - self.data[self.idx.value] = self.target.value - - # update the delay data at the first position - elif self.update_method == CONCAT_UPDATE: - if self.length >= 2: - self.data.value = bm.vstack([self.target.value, self.data[1:]]) - else: - self.data[0] = self.target.value - - def reset(self): + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) + self.data[self.idx.value] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): """Reset the delay data. """ # initialize delay data - self._init_data(self.length) + if self.data is not None: + self._init_data(self.length, batch_size) # time variables - if self.update_method == ROTATE_UPDATE: + if self.method == ROTATE_UPDATE: self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) - def _init_data(self, length): + def _init_data(self, length, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. 
') + if self.target.batch_axis is None: batch_axis = None else: batch_axis = self.target.batch_axis + 1 - self.data = Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), - batch_axis=batch_axis) + self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), + batch_axis=batch_axis) # update delay data self.data[0] = self.target.value - if isinstance(self._initial_delay_data, (Array, jax.Array, float, int, bool)): - self.data[1:] = self._initial_delay_data - elif callable(self._initial_delay_data): - self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype) + if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)): + self.data[1:] = self._before_t0 + elif callable(self._before_t0): + self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype) class NeuLenDelay(LengthDelay): diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 4b73050c5..c9eeec3b7 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -6,23 +6,16 @@ import os import re import sys -from typing import Any, Callable, TypeVar, cast, Dict, Union, Optional +from typing import Any, Callable, TypeVar, cast -import jax -import numpy as np from jax import config, numpy as jnp, devices from jax.lib import xla_bridge -from brainpy._src.tools.dicts import DotDict from . import modes -# from .delayvars import LengthDelay, ROTATE_UPDATE -from .ndarray import Variable, Array bm = None __all__ = [ - 'share', - # default data types 'set_float', 'get_float', 'set_int', 'get_int', @@ -58,75 +51,6 @@ ] -def _key_of_var(var: Variable): - if not isinstance(var, Variable): - raise TypeError(f'Delay target should be instance of Variable. But got {type(var)}') - return f'var{id(var)}' - - -def _as_jax_array(arr): - return arr.value if isinstance(arr, Array) else arr - - -class Context: - """Context for brainpy computation.""" - - def __init__(self): - """Initialize function.""" - - '''Shared data across all nodes at current time step. - ''' - self._arguments = DotDict() - - def get(self, key): - """Get the shared data by the ``key``. - - Args: - key (str): the key to indicate the data. 
- """ - if key in self._arguments: - return self._arguments[key] - else: - raise KeyError(f'Cannot found shared data of {key}.') - - # shared arguments # - # ---------------- # - - def save_shargs(self, **shared) -> None: - """Save shared arguments in the global context.""" - self._arguments.update(shared) - - def get_shargs(self) -> DotDict: - """Get all shared arguments in the global context.""" - r = self._arguments.copy() - return r - - def remove_shargs(self, *args) -> None: - """Clear all shared arguments in the global context.""" - if len(args) > 0: - for a in args: - self._arguments.pop(a) - else: - self._arguments.clear() - - # other # - # ----- # - - def clear(self) -> None: - """Clear all shared data in this computation context.""" - self.remove_shargs() - - -share = Context() -'''Global context manager to manage ``share`` data across all modules.''' - - -def change_share_context(context: Context): - global share - assert isinstance(context, Context), f'Must be instance of {Context.__name__}' - share = context - - # default dtype # -------------------------- diff --git a/brainpy/math/context.py b/brainpy/math/context.py new file mode 100644 index 000000000..19631e022 --- /dev/null +++ b/brainpy/math/context.py @@ -0,0 +1,4 @@ + +from brainpy._src.math.context import ( + share as share, +) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index bae3b9cc2..e1249e751 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -5,6 +5,7 @@ LengthDelay as LengthDelay, NeuTimeDelay as NeuTimeDelay, NeuLenDelay as NeuLenDelay, + DelayVariable as DelayVariable, ROTATE_UPDATE as ROTATE_UPDATE, CONCAT_UPDATE as CONCAT_UPDATE, ) diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index c7e5df414..3c0730d72 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from brainpy._src.math.environment import ( - share as share, set_float as set_float, get_float as get_float, set_int as set_int, From 899fba36dabe526e639fda0875baca95422e9e5b Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 15 Feb 2023 13:57:11 +0800 Subject: [PATCH 3/6] update `share` usages --- brainpy/_src/dyn/base.py | 2 +- brainpy/_src/dyn/runners.py | 2 +- brainpy/_src/experimental/neurons.py | 4 +--- brainpy/_src/experimental/synapses.py | 2 +- brainpy/_src/experimental/synstp.py | 4 ++-- brainpy/_src/math/__init__.py | 1 + brainpy/_src/math/activations.py | 5 ----- brainpy/_src/train/back_propagation.py | 2 +- brainpy/math/__init__.py | 1 + 9 files changed, 9 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index 2963a691e..17522466c 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -155,7 +155,7 @@ def __call__(self, *args, **kwargs): """The shortcut to call ``update`` methods.""" if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): if len(args) and isinstance(args[0], dict): - bm.share.save_shargs(**args[0]) + bm.share.save(**args[0]) return self.update(*args[1:], **kwargs) else: return self.update(*args, **kwargs) diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index fe7ae9fc8..939cc5b94 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -615,7 +615,7 @@ def _step_func_predict(self, shared_args, t, i, x): # input step shared = tools.DotDict(t=t, i=i, dt=self.dt) shared.update(shared_args) - bm.share.save_shargs(**shared) + bm.share.save(**shared) 
self.target.clear_input() self._step_func_input(shared) diff --git a/brainpy/_src/experimental/neurons.py b/brainpy/_src/experimental/neurons.py index d45acc862..e3c31f55b 100644 --- a/brainpy/_src/experimental/neurons.py +++ b/brainpy/_src/experimental/neurons.py @@ -52,8 +52,6 @@ class LIF(NeuGroup): Refractory period length.(ms) V_initializer: ArrayType, Initializer, callable The initializer of membrane potential. - noise: ArrayType, Initializer, callable - The noise added onto the membrane potential method: str The numerical integration method. name: str @@ -125,7 +123,7 @@ def reset_state(self, batch_size=None): @not_pass_shargs def update(self, current): - t = bm.share.get('t') + t = bm.share.load('t') # integrate membrane potential V = self.integral(self.V.value, t, current, bm.dt) diff --git a/brainpy/_src/experimental/synapses.py b/brainpy/_src/experimental/synapses.py index 5bb47df55..1424e3902 100644 --- a/brainpy/_src/experimental/synapses.py +++ b/brainpy/_src/experimental/synapses.py @@ -255,7 +255,7 @@ def update(self, pre_spike): post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates - self.g.value = self.integral(self.g.value, bm.share.get('t'), bm.dt) + post_vs + self.g.value = self.integral(self.g.value, bm.share.load('t'), bm.dt) + post_vs # outputs if self.out is not None: diff --git a/brainpy/_src/experimental/synstp.py b/brainpy/_src/experimental/synstp.py index e4aac3a22..7401ed6ea 100644 --- a/brainpy/_src/experimental/synstp.py +++ b/brainpy/_src/experimental/synstp.py @@ -86,7 +86,7 @@ def reset_state(self, batch_size=None): @not_pass_shargs def update(self, pre_spike): - x = self.integral(self.x.value, bm.share.get('t'), bm.share.get('dt')) + x = self.integral(self.x.value, bm.share.load('t'), bm.share.load('dt')) self.x.value = bm.where(pre_spike, x - self.U * self.x, x) return self.x.value @@ -169,7 +169,7 @@ def reset_state(self, batch_size=None): @not_pass_shargs def update(self, pre_spike): - u, x = self.integral(self.u.value, self.x.value, bm.share.get('t'), bm.get_dt()) + u, x = self.integral(self.u.value, self.x.value, bm.share.load('t'), bm.get_dt()) u = bm.where(pre_spike, u + self.U * (1 - self.u), u) x = bm.where(pre_spike, x - u * self.x, x) self.x.value = x diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index e852fc710..1a6fac48f 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -57,4 +57,5 @@ # environment settings from .modes import * from .environment import * +from .context import share diff --git a/brainpy/_src/math/activations.py b/brainpy/_src/math/activations.py index 59f6948ac..7d558b70f 100644 --- a/brainpy/_src/math/activations.py +++ b/brainpy/_src/math/activations.py @@ -45,7 +45,6 @@ 'swish', 'selu', 'identity', - 'tanh', ] @@ -66,10 +65,6 @@ def get(activation): return global_vars[activation] -def tanh(x): - return jnp.tanh((x.value if isinstance(x, Array) else x)) - - def identity(x): return x.value if isinstance(x, Array) else x diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py index a9e852feb..934009abc 100644 --- a/brainpy/_src/train/back_propagation.py +++ b/brainpy/_src/train/back_propagation.py @@ -566,7 +566,7 @@ def _step_func_fit(self, shared_args, inputs, targets): def _step_func_predict(self, shared, x=None): assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.' 
From 720be63920af410098ba79e0d42791782b3978c8 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 15 Feb 2023 14:30:23 +0800
Subject: [PATCH 4/6] fix bug

---
 brainpy/_src/math/delayvars.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index e17728ca2..41a37dfaa 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -553,7 +553,7 @@ def register_entry(
       entry: str,
       delay_time: Optional[Union[float, bm.Array, Callable]] = None,
       delay_step: Optional[Union[int, bm.Array, Callable]] = None,
-  ) -> 'Delay':
+  ) -> 'DelayVariable':
     """Register an entry to access the data.

     Args:
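The one-line fix above matters because `register_entry` returns `self` (see the `return self` context in the next patch), so the annotation must name the class that actually exists, `DelayVariable`, and callers can chain registrations. A hypothetical sketch, assuming `DelayVariable` is re-exported as the earlier `brainpy/math/delayvars.py` hunk suggests:

    import brainpy.math as bm

    v = bm.Variable(bm.zeros(10))
    delay = bm.DelayVariable(v, length=100)

    # With the corrected return type, registrations chain naturally.
    delay.register_entry('fast', delay_time=2.0).register_entry('slow', delay_time=5.0)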
From b25d7173c9e3a541646204ac0d68b5d1560b04d9 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 15 Feb 2023 14:36:51 +0800
Subject: [PATCH 5/6] fix bugs in delay vars

---
 brainpy/_src/math/delayvars.py | 53 +++++++++++++++++-----------------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index 41a37dfaa..62a71d0fd 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -8,11 +8,12 @@
 from jax import vmap
 from jax.lax import cond, stop_gradient

-from brainpy import check, math as bm
+from brainpy import check
 from brainpy.check import is_float, is_integer, jit_error_checking
 from brainpy.errors import UnsupportedError
 from .arrayinterporate import as_jax
-from .environment import get_dt, get_int
+from .compat_numpy import vstack, broadcast_to
+from .environment import get_dt, get_float, get_int
 from .ndarray import ndarray, Variable, Array
 from .object_transform.base import BrainPyObject

@@ -132,7 +133,7 @@ def __init__(

     # delay_len
     self.t0 = t0
-    self.dt = bm.get_dt() if dt is None else dt
+    self.dt = get_dt() if dt is None else dt
     is_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.)
     self.delay_len = delay_len
     self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1
@@ -146,7 +147,7 @@ def __init__(
     # time variables
     self.idx = Variable(jnp.asarray([0]))
     is_float(t0, 't0', allow_none=False, allow_int=True, )
-    self.current_time = Variable(jnp.asarray([t0], dtype=bm.get_float()))
+    self.current_time = Variable(jnp.asarray([t0], dtype=get_float()))

     # delay data
     batch_axis = None
@@ -462,7 +463,7 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]):

     elif self.update_method == CONCAT_UPDATE:
       if self.num_delay_step >= 2:
-        self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]])
+        self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]])
       else:
         self.data[:] = value

@@ -505,15 +506,15 @@ class DelayVariable(AbstractDelay):

   """

-  data: Optional[bm.Variable]
-  idx: Optional[bm.Variable]
+  data: Optional[Variable]
+  idx: Optional[Variable]
   length: int

   def __init__(
       self,
-      target: bm.Variable,
+      target: Variable,
       length: int = 0,
-      before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
+      before_t0: Union[float, int, bool, Array, jax.Array, Callable] = None,
       entries: Optional[Dict] = None,
       name: str = None,
       method: str = ROTATE_UPDATE,
@@ -524,7 +525,7 @@ def __init__(

     # target
     self.target = target
-    if not isinstance(target, bm.Variable):
+    if not isinstance(target, Variable):
       raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}')

     # delay length
@@ -532,7 +533,7 @@ def __init__(

     # delay data
     if before_t0 is not None:
-      assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable))
+      assert isinstance(before_t0, (int, float, bool, Array, jax.Array, Callable))
     self._before_t0 = before_t0
     if length > 0:
       self._init_data(length)
@@ -541,7 +542,7 @@ def __init__(

     # time variables
     if self.method == ROTATE_UPDATE:
-      self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))
+      self.idx = Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))

     # other info
     self._access_to_step = dict()

@@ -551,8 +552,8 @@ def __init__(
   def register_entry(
       self,
       entry: str,
-      delay_time: Optional[Union[float, bm.Array, Callable]] = None,
-      delay_step: Optional[Union[int, bm.Array, Callable]] = None,
+      delay_time: Optional[Union[float, Array, Callable]] = None,
+      delay_step: Optional[Union[int, Array, Callable]] = None,
   ) -> 'DelayVariable':
     """Register an entry to access the data.
@@ -571,24 +572,24 @@ def register_entry(
     if delay_time is not None:
       if delay_step is not None:
         raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.')
       if callable(delay_time):
-        delay_time = bm.as_jax(delay_time(self.delay_target_shape))
-        delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int())
+        delay_time = as_jax(delay_time(self.delay_target_shape))
+        delay_step = jnp.asarray(delay_time / get_dt(), dtype=get_int())
       elif isinstance(delay_time, float):
-        delay_step = int(delay_time / bm.get_dt())
+        delay_step = int(delay_time / get_dt())
       else:
-        delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int())
+        delay_step = jnp.asarray(as_jax(delay_time) / get_dt(), dtype=get_int())

     # delay steps
     if delay_step is None:
       delay_type = 'none'
     elif isinstance(delay_step, int):
       delay_type = 'homo'
-    elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)):
+    elif isinstance(delay_step, (Array, jax.Array, np.ndarray)):
       if delay_step.size == 1 and delay_step.ndim == 0:
         delay_type = 'homo'
       else:
         delay_type = 'heter'
-        delay_step = bm.Array(delay_step)
+        delay_step = Array(delay_step)
     elif callable(delay_step):
       delay_step = delay_step(self.delay_target_shape)
       delay_type = 'heter'
@@ -617,7 +618,7 @@ def register_entry(
       self._access_to_step[entry] = delay_step
     return self

-  def at(self, entry: str, *indices) -> bm.Array:
+  def at(self, entry: str, *indices) -> Array:
     """Get the data at the given entry.

     Args:
@@ -692,7 +693,7 @@ def retrieve(self, delay_step, *indices):
     # the delay data
     return self.data[indices]

-  def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None:
+  def update(self, latest_value: Optional[Union[Array, jax.Array]] = None) -> None:
     """Update delay variable with the new data.
     """
     if self.data is not None:
@@ -702,13 +703,13 @@ def update(self, latest_value: Optional[Union[Array, jax.Array]] = None) -> None:

     # update the delay data at the rotation index
     if self.method == ROTATE_UPDATE:
-      self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1)))
+      self.idx.value = stop_gradient(as_jax((self.idx - 1) % (self.length + 1)))
       self.data[self.idx.value] = latest_value

     # update the delay data at the first position
     elif self.method == CONCAT_UPDATE:
       if self.length >= 2:
-        self.data.value = bm.vstack([latest_value, self.data[1:]])
+        self.data.value = vstack([latest_value, self.data[1:]])
       else:
         self.data[0] = latest_value

@@ -735,11 +736,11 @@ def _init_data(self, length, batch_size: int = None):
       batch_axis = None
     else:
       batch_axis = self.target.batch_axis + 1
-    self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype),
+    self.data = Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype),
                          batch_axis=batch_axis)

     # update delay data
     self.data[0] = self.target.value
-    if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)):
+    if isinstance(self._before_t0, (Array, jax.Array, float, int, bool)):
       self.data[1:] = self._before_t0
     elif callable(self._before_t0):
       self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype)
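The patch above strips every `bm.*` indirection out of `delayvars.py` in favor of direct imports (`as_jax`, `vstack`, `broadcast_to`, `get_dt`, ...), breaking the module's circular dependency on the top-level `brainpy.math` package without changing behavior. The rotation bookkeeping that `update` and `retrieve` implement is plain modular arithmetic over a buffer of `length + 1` slots; a pure-Python sketch for intuition (not brainpy code):

    L = 3                                # max delay, in steps
    buf, idx = [0.0] * (L + 1), 0        # history buffer and write position
    for step, value in enumerate([1.0, 2.0, 3.0, 4.0]):
      idx = (idx - 1) % (L + 1)          # rotate the write position backwards
      buf[idx] = value                   # newest sample overwrites the oldest
      d = 2                              # read the value from two steps ago
      print(step, buf[(idx + d) % (L + 1)])  # prints 0.0, 0.0, 1.0, 2.0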
From 96a06526e35e77ebbe5bbd53ccfb4aafaad70af9 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 15 Feb 2023 15:05:37 +0800
Subject: [PATCH 6/6] fix bugs in delay vars

---
 brainpy/_src/dyn/base.py               |  5 ++--
 brainpy/_src/dyn/runners.py            |  4 ++--
 brainpy/_src/experimental/delay.py     |  4 ++--
 brainpy/_src/math/context.py           | 32 +++++++++++++-------------
 brainpy/_src/math/delayvars.py         |  2 +-
 brainpy/_src/train/back_propagation.py |  3 ++-
 brainpy/math/activations.py            |  1 -
 7 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
index 17522466c..b3eba41aa 100644
--- a/brainpy/_src/dyn/base.py
+++ b/brainpy/_src/dyn/base.py
@@ -134,7 +134,7 @@ def __init__(
     self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()

     # super initialization
-    super(DynamicalSystem, self).__init__(name=name)
+    BrainPyObject.__init__(self, name=name)

   @property
   def mode(self) -> bm.Mode:
@@ -155,7 +155,8 @@ def __call__(self, *args, **kwargs):
     """The shortcut to call ``update`` methods."""
     if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
       if len(args) and isinstance(args[0], dict):
-        bm.share.save(**args[0])
+        for k, v in args[0].items():
+          bm.share.save(k, v)
         return self.update(*args[1:], **kwargs)
       else:
         return self.update(*args, **kwargs)
diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py
index 939cc5b94..dd59e473c 100644
--- a/brainpy/_src/dyn/runners.py
+++ b/brainpy/_src/dyn/runners.py
@@ -615,7 +615,8 @@ def _step_func_predict(self, shared_args, t, i, x):
     # input step
     shared = tools.DotDict(t=t, i=i, dt=self.dt)
     shared.update(shared_args)
-    bm.share.save(**shared)
+    for k, v in shared.items():
+      bm.share.save(k, v)
     self.target.clear_input()
     self._step_func_input(shared)
@@ -630,7 +631,6 @@ def _step_func_predict(self, shared_args, t, i, x):
     # finally
     if self.progress_bar:
       id_tap(lambda *arg: self._pbar.update(), ())
-    bm.share.remove_shargs()
     return out, mon

   def _get_f_predict(self, shared_args: Dict = None):
diff --git a/brainpy/_src/experimental/delay.py b/brainpy/_src/experimental/delay.py
index d4222e30b..35185ee9b 100644
--- a/brainpy/_src/experimental/delay.py
+++ b/brainpy/_src/experimental/delay.py
@@ -21,10 +21,11 @@ def __init__(
       length: int = 0,
       before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
       entries: Optional[Dict] = None,
-      method: str = None,
+      method: str = ROTATE_UPDATE,
       mode: bm.Mode = None,
       name: str = None,
   ):
+    DynamicalSystem.__init__(self, mode=mode)
     if method is None:
       if self.mode.is_a(bm.NonBatchingMode):
         method = ROTATE_UPDATE
@@ -32,7 +33,6 @@ def __init__(
         method = CONCAT_UPDATE
       else:
         method = ROTATE_UPDATE
-    DynamicalSystem.__init__(self, mode=mode)
     DelayVariable.__init__(self,
                            target=target,
                            length=length,
diff --git a/brainpy/_src/math/context.py b/brainpy/_src/math/context.py
index 7283569cc..a4110901e 100644
--- a/brainpy/_src/math/context.py
+++ b/brainpy/_src/math/context.py
@@ -64,24 +64,24 @@ def save(self, identifier: str, data: Any) -> None:
     """Save shared arguments in the global context."""
     assert isinstance(identifier, str)
-    if identifier in self._identifiers:
-      raise ValueError(f'{identifier} has been used. Please assign another name.')
     if isinstance(data, DelayVariable):
+      if identifier in self._identifiers:
+        raise ValueError(f'{identifier} has been used. Please assign another name.')
       self._delays[identifier] = data
-    elif isinstance(data, DelayEntry):
-      if isinstance(data.target, DelayVariable):
-        delay_key = f'delay{id(data)}'
-        self.save(delay_key, data.target)
-        delay = data.target
-      elif isinstance(data.target, str):
-        if data.target not in self._delays:
-          raise ValueError(f'Delay target {data.target} has not been registered.')
-        delay = self._delays[data.target]
-        delay_key = data.target
-      else:
-        raise ValueError(f'Unknown delay target. {type(data.target)}')
-      delay.register_entry(identifier, delay_time=data.time, delay_step=data.step)
-      self._delay_entries[identifier] = delay_key
+    # elif isinstance(data, DelayEntry):
+    #   if isinstance(data.target, DelayVariable):
+    #     delay_key = f'delay{id(data)}'
+    #     self.save(delay_key, data.target)
+    #     delay = data.target
+    #   elif isinstance(data.target, str):
+    #     if data.target not in self._delays:
+    #       raise ValueError(f'Delay target {data.target} has not been registered.')
+    #     delay = self._delays[data.target]
+    #     delay_key = data.target
+    #   else:
+    #     raise ValueError(f'Unknown delay target. {type(data.target)}')
+    #   delay.register_entry(identifier, delay_time=data.time, delay_step=data.step)
+    #   self._delay_entries[identifier] = delay_key
     else:
       self._arguments[identifier] = data
     self._identifiers.add(identifier)
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index 62a71d0fd..fa0fd193e 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -519,7 +519,7 @@ def __init__(
       name: str = None,
       method: str = ROTATE_UPDATE,
   ):
-    super().__init__(name=name)
+    BrainPyObject.__init__(self, name=name)
     assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
     self.method = method
diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py
index 934009abc..ac76b93a5 100644
--- a/brainpy/_src/train/back_propagation.py
+++ b/brainpy/_src/train/back_propagation.py
@@ -566,7 +566,8 @@ def _step_func_fit(self, shared_args, inputs, targets):

   def _step_func_predict(self, shared, x=None):
     assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
-    bm.share.save(**shared)
+    for k, v in shared.items():
+      bm.share.save(k, v)

     # input step
     self.target.clear_input()
diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py
index 655447068..0096090f5 100644
--- a/brainpy/math/activations.py
+++ b/brainpy/math/activations.py
@@ -24,5 +24,4 @@
   swish as swish,
   selu as selu,
   identity as identity,
-  tanh as tanh,
 )
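A closing note on the recurring `BrainPyObject.__init__(self, name=name)` pattern in this last patch: the experimental `Delay` inherits from both `DynamicalSystem` and `DelayVariable`, so inside `DynamicalSystem.__init__` a bare `super(DynamicalSystem, self).__init__(name=name)` resolves along the instance's MRO to `DelayVariable.__init__`, which expects a delay target rather than a name. Pinning the shared base class sidesteps that. A self-contained illustration with hypothetical class names:

    class Base:
      def __init__(self, name=None):
        self.name = name

    class DynamicalLike(Base):
      def __init__(self, mode=None):
        self.mode = mode
        Base.__init__(self, name=None)   # pin the base; do not walk the MRO

    class DelayLike(Base):
      def __init__(self, length=0):
        self.length = length
        Base.__init__(self, name=None)

    class Combined(DynamicalLike, DelayLike):
      def __init__(self):
        DynamicalLike.__init__(self, mode='nonbatching')
        DelayLike.__init__(self, length=10)

    c = Combined()           # both parent initializers run exactly once
    print(c.mode, c.length)  # nonbatching 10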