diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 17493d79a..084ba8025 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -79,6 +79,11 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj +def _check_out(out): + if not isinstance(out, Array): + raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}') + + class Array(object): """Multiple-dimensional array in BrainPy. """ @@ -512,9 +517,16 @@ def choose(self, choices, mode='raise'): """Use an index array to construct a new array from a set of choices.""" return _return(self.value.choose(choices=_as_jax_array_(choices), mode=mode)) - def clip(self, min=None, max=None): + def clip(self, min=None, max=None, out=None, ): """Return an array whose values are limited to [min, max]. One of max or min must be given.""" - return _return(self.value.clip(min=min, max=max)) + min = _as_jax_array_(min) + max = _as_jax_array_(max) + r = self.value.clip(min=min, max=max) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r def compress(self, condition, axis=None): """Return selected slices of this array along given axis.""" @@ -999,11 +1011,11 @@ def view(self, *args, dtype=None): else: return _return(self.value.view(dtype)) else: - if isinstance(args[0], int): # shape + if isinstance(args[0], int): # shape if dtype is not None: raise ValueError('Provide one of dtype or shape. Not both.') return _return(self.value.reshape(*args)) - else: # dtype + else: # dtype assert not isinstance(args[0], int) assert dtype is None return _return(self.value.view(args[0])) @@ -1115,25 +1127,6 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': """ return Array(jnp.expand_dims(self.value, axis)) - def expand(self, *shape: Union[int, Sequence[int]]) -> 'Array': - """ - Expand an array to a new shape. - - Parameters - ---------- - shape : tuple or int - The shape of the desired array. A single integer ``i`` is interpreted - as ``(i,)``. - - Returns - ------- - expanded : Array - A readonly view on the original array with the given shape. It is - typically not contiguous. Furthermore, more than one element of a - expanded array may refer to a single memory location. - """ - return Array(jnp.broadcast_to(self._value, shape)) - def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': """ Expand an array to a shape of another array. @@ -1153,381 +1146,266 @@ def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': array = Array(array) return Array(jnp.broadcast_to(self.value, array.value.shape)) - def squeeze(self, - axis: Optional[Union[int, Sequence]]=None) -> 'Array': - return Array(self.squeeze(axis)) - # def item(self, *args) -> Any: # return self.value.item(*args) def pow(self, index: int): - return self._value ** index - - def addr(self, - vec1: Union['Array', jax.Array, np.ndarray], - vec2: Union['Array', jax.Array, np.ndarray], - *, - beta: float = 1.0, - alpha: float = 1.0, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union[None, NoReturn]: - if not isinstance(beta, int) and not isinstance(beta, float): - raise Exception('Wrong beta param of addr') - if not isinstance(alpha, int) and not isinstance(alpha, float): - raise Exception('Wrong alpha param of addr') - if not isinstance(vec1, Array): - vec1 = Array(vec1) - if not isinstance(vec2, Array): - vec2 = Array(vec2) - if not isinstance(out, Array): - out = Array(out) - return _return(brainpy.math.outer(vec1, vec2, out=out)) - - - def addr_(self, - vec1: Union['Array', jax.Array, np.ndarray], - vec2: Union['Array', jax.Array, np.ndarray], - *, - beta: float = 1.0, - alpha: float = 1.0) -> Union['Array', NoReturn]: - if not isinstance(beta, (int,float)): - raise Exception('Wrong beta param of addr') - if not isinstance(alpha, (int,float)): - raise Exception('Wrong alpha param of addr') - if not isinstance(vec1, Array): - vec1 = Array(vec1) - if not isinstance(vec2, Array): - vec2 = Array(vec2) - # self.value *= beta - # self.value += alpha * jnp.outer(vec1, vec2) - return brainpy.math.outer(vec1, vec2, out=self) - - def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> Union[NoReturn, None]: - # if other is None: - # raise Exception('Array can not make outer product with None') - if not isinstance(other, Array): - other = Array(other) - return _return(jnp.outer(self.value, other.value)) + return _return(self._value ** index) - def sum(self) -> 'Array': - return _return(self.value.sum()) + def addr( + self, + vec1: Union['Array', jax.Array, np.ndarray], + vec2: Union['Array', jax.Array, np.ndarray], + *, + beta: float = 1.0, + alpha: float = 1.0, + out: Optional[Union['Array', jax.Array, np.ndarray]] = None + ) -> Optional['Array']: + r"""Performs the outer-product of vectors ``vec1`` and ``vec2`` and adds it to the matrix ``input``. + + Optional values beta and alpha are scaling factors on the outer product + between vec1 and vec2 and the added matrix input respectively. - def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': - # return Array(self.value.__abs__()) - abs_value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - abs_value = brainpy.math.abs(self.value, out=out) + .. math:: + + out = \beta \mathrm{input} + \alpha (\text{vec1} \bigtimes \text{vec2}) + + Args: + vec1: the first vector of the outer product + vec2: the second vector of the outer product + beta: multiplier for input + alpha: multiplier + out: the output tensor. + + """ + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def addr_( + self, + vec1: Union['Array', jax.Array, np.ndarray], + vec2: Union['Array', jax.Array, np.ndarray], + *, + beta: float = 1.0, + alpha: float = 1.0 + ) -> None: + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + self.value = r + + def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> 'Array': + other = _as_jax_array_(other) + return _return(jnp.outer(self.value, other.value)) + + def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.abs(self.value) + if out is None: + return _return(r) else: - abs_value = brainpy.math.abs(self.value) - # if isinstance(out, (Array, jax.Array, np.ndarray)): - # out.value = abs_value - return _return(abs_value) + _check_out(out) + out.value = r - def abs_(self) -> 'Array': + def abs_(self) -> None: """ in-place version of Array.abs() """ - return brainpy.math.abs(self, out=self) + self.value = jnp.abs(self.value) - def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: """ alias of Array.abs """ - if not isinstance(out, Array): - out = Array(out) return self.abs(out=out) - def absolute_(self) -> 'Array': + def absolute_(self) -> None: """ alias of Array.abs_() """ return self.abs_() - def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.sin(self.value, out=out) + def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.sin(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.sin(self.value) - return Array(value) - - def sin_(self) -> 'Array': - return Array(brainpy.math.sin(self.value, out=self)) - - def cos_(self) -> 'Array': - return Array(brainpy.math.cos(self.value, out=self)) - - def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.cos(self.value, out=out.value) + _check_out(out) + out.value = r + + def sin_(self) -> None: + self.value = jnp.sin(self.value) + + def cos_(self) -> None: + self.value = jnp.cos(self.value) + + def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.cos(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def tan_(self) -> None: + self.value = jnp.tan(self.value) + + def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.tan(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.cos(self.value) - return Array(value) - - def tan_(self) -> 'Array': - return Array(brainpy.math.tan(self.value, out=self.value)) - - def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.tan(self.value, out=out.value) + _check_out(out) + out.value = r + + def sinh_(self) -> None: + self.value = jnp.tanh(self.value) + + def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.tanh(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def cosh_(self) -> None: + self.value = jnp.cosh(self.value) + + def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.cosh(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.tan(self.value) - return Array(value) - - def sinh_(self) -> 'Array': - return Array(brainpy.math.sinh(self.value, out=self.value)) - - def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.sinh(self.value, out=out.value) + _check_out(out) + out.value = r + + def tanh_(self) -> None: + self.value = jnp.tanh(self.value) + + def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.tanh(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.sinh(self.value) - return Array(value) - - def cosh_(self) -> 'Array': - return Array(brainpy.math.cosh(self.value, out=self.value)) - - def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - if not isinstance(out, Array): - out = Array(out) - return Array(brainpy.math.cosh(self.value, out=out.value)) - - def tanh_(self) -> 'Array': - return Array(brainpy.math.tanh(self.value, out=self.value)) - - def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.tanh(self.value, out=out.value) + _check_out(out) + out.value = r + + def arcsin_(self) -> None: + self.value = jnp.arcsin(self.value) + + def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arcsin(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.tanh(self.value) - return Array(value) - - def arcsin_(self) -> 'Array': - return Array(brainpy.math.arcsin(self.value, out=self.value)) - - def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - if not isinstance(out, Array): - out = Array(out) - return Array(brainpy.math.arcsin(self.value, out=out.value)) - - def arccos_(self) -> 'Array': - return Array(brainpy.math.arccos(self.value, out=self.value)) - - def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.arccos(self.value, out=out.value) + _check_out(out) + out.value = r + + def arccos_(self) -> None: + self.value = jnp.arccos(self.value) + + def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arccos(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.arccos(self.value) - return Array(value) - - def arctan_(self) -> 'Array': - return Array(brainpy.math.arctan(self.value, out=self.value)) - - def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Union['Array', NoReturn]: - # return Array(self.value.__abs__()) - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.arctan(self.value, out=out.value) + _check_out(out) + out.value = r + + def arctan_(self) -> None: + self.value = jnp.arctan(self.value) + + def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arctan(self.value) + if out is None: + return _return(r) else: - value = brainpy.math.arctan(self.value) - return Array(value) - - # def all(self, - # axis: Optional[int] = None, - # keepdim: bool = False, - # *, - # out: Optional[Union['Array', jax.Array, np.ndarray]] = None): - # """ - # test if all element cast to true - # """ - # # if out is not None: - # # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # # raise Exception('Unexcepted param out') - # value = value = brainpy.math.all(self.value, axis, keepdim) - # if out is not None: - # if not isinstance(out, Array): - # warnings.showwarning("out is not a brainpy Array") - # out = Array(out) - # out.update(value) - # return Array(value) - - # def any(self, - # dim: int, - # keepdim: bool, - # *, - # out: Optional[Union['Array', jax.Array, np.ndarray]] = None): - # """ - # test if any element cast to true - # """ - # value = value = brainpy.math.any(self.value) - # if out is not None: - # if not isinstance(out, Array): - # warnings.showwarning("out is not a brainpy Array") - # out = Array(out) - # out.update(value) - # return Array(value) - - def clamp(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - *, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + _check_out(out) + out.value = r + + def clamp( + self, + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + *, + out: Optional[Union['Array', jax.Array, np.ndarray]] = None + ) -> Optional['Array']: """ return the value between min_value and max_value, if min_value is None, then no lower bound, if max_value is None, then no upper bound. """ - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - - # value = None - # if out is not None: - # if not isinstance(out, Array): - # out = Array(out) - # value = brainpy.math.clip(self.value, min_value, max_value, out=out) - # else: - # value = brainpy.math.clip(self.value, min_value, max_value) - # return Array(value) - - return _return(self.value.clip(min_value, max_value)) + min_value = _as_jax_array_(min_value) + max_value = _as_jax_array_(max_value) + r = jnp.clip(self.value, max_value, max_value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r def clamp_(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None: """ return the value between min_value and max_value, if min_value is None, then no lower bound, if max_value is None, then no upper bound. """ - return brainpy.math.clip(self.value, min_value, max_value, out=self) + self.clamp(min_value, max_value, out=self) def clip_(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': - """ - alias for clamp_ - """ - return Array(brainpy.math.clip(self.value, min_value, max_value, out=self)) - - def clip(self, min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - *, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> 'Array': + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None: """ - alias for clamp + alias for clamp_ """ - # if out is not None: - # if not isinstance(out, (Array, jax.Array, np.ndarray)): - # raise Exception('Unexcepted param out') - value = None - if out is not None: - if not isinstance(out, Array): - out = Array(out) - value = brainpy.math.clip(self.value, out=out) - else: - value = brainpy.math.clip(self.value) - return Array(value) + return self.clip(min_value, max_value, out=self) def clone(self) -> 'Array': - return Array(brainpy.math.copy(self.value)) - - def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': - value = None - if src is not None: - if not isinstance(src, Array): - src = Array(src) - value = brainpy.math.copyto(self.value, src) - else: - raise Exception("copy from None???") - return self - - # def conj(self) -> 'Array': - # return Array(brainpy.math.conj(self.value)) - - def cov_with(self, - y: Optional[Union['Array', jax.Array, np.ndarray]] = None, - rowvar: bool = True, - bias: bool = False, - ddof: Optional[int] = None, - fweights: Union['Array', jax.Array, np.ndarray] = None, - aweights: Union['Array', jax.Array, np.ndarray] = None) -> 'Array': - return Array(brainpy.math.cov(self.value, y, rowvar, bias, fweights, aweights)) - - def cov(self, - *, - correction: int = 1, - fweights: Union['Array', jax.Array, np.ndarray] = None, - aweights: Union['Array', jax.Array, np.ndarray] = None) -> Union['Array', NoReturn]: - try: - x = [e[0] for e in self.value] - y = [e[1] for e in self.value] - return Array(brainpy.math.cov(x, y, ddof=correction, fweights=fweights, aweights=aweights)) - except Exception as e: - raise Exception('Wrong format, need to be [[x1,y1],[x2,y2],[x3,y3]]') + return Array(self.value.copy()) + def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> None: + self.value = jnp.copy(_as_jax_array_(src)) - # ------------------ - # Torch support - # ------------------ + def cov_with( + self, + y: Optional[Union['Array', jax.Array, np.ndarray]] = None, + rowvar: bool = True, + bias: bool = False, + ddof: Optional[int] = None, + fweights: Union['Array', jax.Array, np.ndarray] = None, + aweights: Union['Array', jax.Array, np.ndarray] = None + ) -> 'Array': + y = _as_jax_array_(y) + fweights = _as_jax_array_(fweights) + aweights = _as_jax_array_(aweights) + r = jnp.cov(self.value, y, rowvar, bias, fweights, aweights) + return Array(r) def expand(self, *sizes) -> 'Array': + """ + Expand an array to a new shape. + + Parameters + ---------- + shape : tuple or int + The shape of the desired array. A single integer ``i`` is interpreted + as ``(i,)``. + + Returns + ------- + expanded : Array + A readonly view on the original array with the given shape. It is + typically not contiguous. Furthermore, more than one element of a + expanded array may refer to a single memory location. + """ l_ori = len(self.shape) l_tar = len(sizes) base = l_tar - l_ori @@ -1537,12 +1415,14 @@ def expand(self, *sizes) -> 'Array': f'dimensions in the tensor ({len(self.shape)})') for i, v in enumerate(sizes[:base]): if v < 0: - raise ValueError(f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}') + raise ValueError( + f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}') for i, v in enumerate(self.shape): sizes_list[base + i] = v if sizes_list[base + i] == -1 else sizes_list[base + i] if v != 1 and sizes_list[base + i] != v: - raise ValueError(f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' - f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') + raise ValueError( + f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' + f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') return Array(jnp.broadcast_to(self.value, sizes_list))