From 397c3c0effff8e62462583c392e96273808f74f9 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 26 Dec 2022 21:13:44 +0800 Subject: [PATCH 1/6] improve transformations --- brainpy/math/object_transform/autograd.py | 768 +++++++--------------- brainpy/math/object_transform/controls.py | 22 +- 2 files changed, 253 insertions(+), 537 deletions(-) diff --git a/brainpy/math/object_transform/autograd.py b/brainpy/math/object_transform/autograd.py index b726201f7..ae7aafaa7 100644 --- a/brainpy/math/object_transform/autograd.py +++ b/brainpy/math/object_transform/autograd.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from functools import partial -from typing import Union, Callable, Dict, Sequence, Any +from functools import partial, wraps +from typing import Union, Callable, Dict, Sequence, Any, Optional import jax import numpy as np @@ -12,13 +12,14 @@ _check_output_dtype_jacfwd, _check_input_dtype_jacfwd, ) from jax.api_util import argnums_partial from jax.errors import UnexpectedTracerError -from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_transpose, tree_structure +from jax.tree_util import (tree_flatten, tree_unflatten, + tree_map, tree_transpose, + tree_structure) from jax.util import safe_map -from brainpy import errors, tools -from brainpy.base import get_unique_name, ArrayCollector -from brainpy.math.ndarray import Array, add_context, del_context -from ._utils import infer_dyn_vars +from brainpy import errors, tools, check +from brainpy.base import BrainPyObject +from brainpy.math.ndarray import Array, Variable, add_context, del_context from .base import ObjectTransform __all__ = [ @@ -29,65 +30,99 @@ ] -class GradientFunTransform(ObjectTransform): - _excluded_vars = ('_origin_fun',) +class GradientTransform(ObjectTransform): + """Object-oriented Automatic Differentiation Transformation in BrainPy. 
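  This wrapper is what the public functions below (``grad``, ``jacrev``,
  ``jacfwd``, ``vector_grad``) build on: it flattens and registers
  ``grad_vars`` together with ``dyn_vars`` and ``child_objs``, and shifts any
  user-supplied ``argnums`` by two, because the function handed to the
  underlying JAX transform receives ``(grad_values, dyn_values, *args)``.
  A rough sketch of how ``grad`` (defined later in this file) constructs it;
  ``loss_fun``, ``w`` and ``x`` are illustrative placeholders:

  >>> t = GradientTransform(target=loss_fun, transform=jax.grad,
  ...                       grad_vars=w, dyn_vars=dict(), child_objs=dict(),
  ...                       argnums=None, return_value=False, has_aux=False)
  >>> grads_of_w = t(x)  # gradients of ``loss_fun`` w.r.t. the variable ``w``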
+ """ def __init__( self, - grad_func: Callable, - dyn_vars: Any, - grad_vars: Any, - name: str = None, - origin_fun=None - ): - super().__init__(name=name) - - self._origin_fun = origin_fun - self._f = grad_func - self.register_implicit_vars(dyn_vars, grad_vars) + target: Callable, + transform: Callable, - def __call__(self, *args, **kwargs): - return self._f(*args, **kwargs) - - def __repr__(self): - name = self.__class__.__name__ - f = tools.repr_object(self._origin_fun) - f = tools.repr_context(f, " " * (len(name) + 6)) - return f'{name}(target={f})' - - -class GradientTransform(ObjectTransform): - _excluded_vars = ('_origin_fun',) + # variables and nodes + grad_vars: Any, + dyn_vars: Dict[str, Variable], + child_objs: Dict[str, Variable], - def __init__( - self, - grad_func: Callable, - grad_tree, - grad_vars, - dyn_vars, - argnums, + # gradient setting + argnums: Optional[Union[int, Sequence[int]]], return_value: bool, has_aux: bool, + transform_setting: Optional[Dict[str, Any]] = None, + + # other name: str = None, - origin_fun=None ): super().__init__(name=name) - self.register_implicit_vars(dyn_vars, grad_vars) - self._grad_func = grad_func - self._grad_tree = grad_tree - self._grad_vars = grad_vars - self._dyn_vars = dyn_vars - self._argnums = argnums - self._return_value = return_value - self._has_aux = has_aux - self._origin_fun = origin_fun - - self.register_implicit_vars(dyn_vars, grad_vars) + # gradient variables + self._grad_vars, self._grad_tree = tree_flatten(grad_vars, is_leaf=lambda a: isinstance(a, Array)) + + # register variables and nodes + self.register_implicit_vars(dyn_vars, self._grad_vars) + self.register_implicit_nodes(child_objs) + + # parameters + if argnums is None and len(self._grad_vars) == 0: + argnums = 0 + if argnums is None: + assert len(self._grad_vars) > 0 + _argnums = 0 + elif isinstance(argnums, int): + _argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2) + else: + _argnums = check.is_sequence(argnums, elem_type=int, allow_none=False) + _argnums = tuple(a + 2 for a in _argnums) + if len(self._grad_vars) > 0: + _argnums = (0,) + _argnums + self.nonvar_argnums = argnums + self.return_value = return_value + self.has_aux = has_aux + + # target and transform + self.target = target + self.transform = transform + self._dyn_vars = tuple((self.vars().unique() - self._grad_vars).values()) + + # settings + transform_setting = dict() if transform_setting is None else transform_setting + if self.has_aux: + self._call = transform(self._f_grad_with_aux_to_transform, + argnums=_argnums, + has_aux=True, + **transform_setting) + else: + self._call = transform(self._f_grad_without_aux_to_transform, + argnums=_argnums, + has_aux=True, + **transform_setting) + + def _f_grad_with_aux_to_transform(self, grad_values, dyn_values, *args, **kwargs): + for v, d in zip(self._dyn_vars, dyn_values): v._value = d + for v, d in zip(self._grad_vars, grad_values): v._value = d + # Users should return the auxiliary data like:: + # >>> # 1. example of return one data + # >>> return scalar_loss, data + # >>> # 2. example of return multiple data + # >>> return scalar_loss, (data1, data2, ...) 
+ outputs = self.target(*args, **kwargs) + # outputs: [0] is the value for gradient, + # [1] is other values for return + output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0]) + return output0, (outputs, [v.value for v in self._grad_vars], [v.value for v in self._dyn_vars]) + + def _f_grad_without_aux_to_transform(self, grad_values, dyn_values, *args, **kwargs): + for v, d in zip(self._dyn_vars, dyn_values): v._value = d + for v, d in zip(self._grad_vars, grad_values): v._value = d + # Users should return the scalar value like this:: + # >>> return scalar_loss + output = self.target(*args, **kwargs) + output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output) + return output0, (output, [v.value for v in self._grad_vars], [v.value for v in self._dyn_vars]) def __repr__(self): name = self.__class__.__name__ - f = tools.repr_object(self._origin_fun) + f = tools.repr_object(self.target) f = tools.repr_context(f, " " * (len(name) + 6)) format_ref = (f'{name}(target={f}, \n' + f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n' @@ -99,16 +134,16 @@ def __call__(self, *args, **kwargs): old_dyn_vs = [v.value for v in self._dyn_vars] try: add_context(self.name) - grads, (outputs, new_grad_vs, new_dyn_vs) = self._grad_func(old_grad_vs, - old_dyn_vs, - *args, - **kwargs) + grads, (outputs, new_grad_vs, new_dyn_vs) = self._call(old_grad_vs, + old_dyn_vs, + *args, + **kwargs) del_context(self._name) except UnexpectedTracerError as e: del_context(self._name) for v, d in zip(self._grad_vars, old_grad_vs): v._value = d for v, d in zip(self._dyn_vars, old_dyn_vs): v._value = d - raise errors.JaxTracerError(variables=self._dyn_vars + self._grad_vars) from e + raise errors.JaxTracerError() from e except Exception as e: del_context(self._name) for v, d in zip(self._grad_vars, old_grad_vs): v._value = d @@ -119,174 +154,41 @@ def __call__(self, *args, **kwargs): for v, d in zip(self._dyn_vars, new_dyn_vs): v._value = d # check returned grads - if len(self._grad_vars) == 0: - grads = grads[1] if isinstance(self._argnums, int) else grads[1:] - else: - var_grads = self._grad_tree.unflatten(grads[0]) - if self._argnums is None: - grads = var_grads + if len(self._grad_vars) > 0: + if self.nonvar_argnums is None: + grads = self._grad_tree.unflatten(grads) else: - arg_grads = grads[1] if isinstance(self._argnums, int) else grads[1:] + var_grads = self._grad_tree.unflatten(grads[0]) + arg_grads = grads[1] if isinstance(self.nonvar_argnums, int) else grads[1:] grads = (var_grads, arg_grads) # check returned value - if self._return_value: + if self.return_value: # check aux - if self._has_aux: + if self.has_aux: return grads, outputs[0], outputs[1] else: return grads, outputs else: # check aux - if self._has_aux: + if self.has_aux: return grads, outputs[1] else: return grads -def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars, - argnums, return_value, has_aux): - name = get_unique_name('_brainpy_object_oriented_grad_') - - # outputs - def call_func(*args, **kwargs): - old_grad_vs = [v.value for v in grad_vars] - old_dyn_vs = [v.value for v in dyn_vars] - try: - add_context(name) - grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs, - old_dyn_vs, - *args, - **kwargs) - del_context(name) - except UnexpectedTracerError as e: - del_context(name) - for v, d in zip(grad_vars, old_grad_vs): v._value = d - for v, d in zip(dyn_vars, old_dyn_vs): v._value = d - raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e - except 
Exception as e: - del_context(name) - for v, d in zip(grad_vars, old_grad_vs): v._value = d - for v, d in zip(dyn_vars, old_dyn_vs): v._value = d - raise e - for v, d in zip(grad_vars, new_grad_vs): v._value = d - for v, d in zip(dyn_vars, new_dyn_vs): v._value = d - - # check returned grads - if len(grad_vars) == 0: - grads = grads[1] if isinstance(argnums, int) else grads[1:] - else: - var_grads = grad_tree.unflatten(grads[0]) - if argnums is None: - grads = var_grads - else: - arg_grads = grads[1] if isinstance(argnums, int) else grads[1:] - grads = (var_grads, arg_grads) - - # check returned value - if return_value: - # check aux - if has_aux: - return grads, outputs[0], outputs[1] - else: - return grads, outputs - else: - # check aux - if has_aux: - return grads, outputs[1] - else: - return grads - - return call_func - - -def _check_vars(variables): - if variables is None: - vars, tree = tree_flatten(variables, is_leaf=lambda a: isinstance(a, Array)) - return vars, tree - if isinstance(variables, dict): - variables = dict(variables) - elif isinstance(variables, (list, tuple)): - variables = tuple(variables) - elif isinstance(variables, Array): - pass - else: - raise ValueError - vars, tree = tree_flatten(variables, is_leaf=lambda a: isinstance(a, Array)) - for v in vars: - if not isinstance(v, Array): - raise ValueError(f'"dyn_vars" and "grad_vars" only supports dict ' - f'of Array, but got {type(v)}: {v}') - return vars, tree - - -def _grad_checking(func: Callable, - dyn_vars: Union[Dict, Sequence], - grad_vars: Union[Dict, Sequence]): - # check function - if not callable(func): - raise ValueError(f'Must be a callable object. But we got {func}') - - # check "vars", make sure it is an instance of ArrayCollector - dyn_vars, _ = _check_vars(dyn_vars) - grad_vars, grad_tree = _check_vars(grad_vars) - - # check the duplicate in "dyn_vars" and "grad_vars" - dyn_vars = tuple(ArrayCollector.from_other(dyn_vars).unique().values()) - new_dyn_vars = [] - _dyn_var_ids = set([id(v) for v in grad_vars]) - for v in dyn_vars: - if id(v) not in _dyn_var_ids: - new_dyn_vars.append(v) - _dyn_var_ids.add(id(v)) - return tuple(new_dyn_vars), grad_vars, grad_tree - - -def _cls_grad(func, grad_vars, dyn_vars, argnums, has_aux=False, - holomorphic=False, allow_int=False, reduce_axes=()): - # parameters - assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(argnums, (tuple, list)) # tuple/list of int - - # gradient functions - if has_aux: - @partial(jax.grad, argnums=argnums, has_aux=True, holomorphic=holomorphic, - allow_int=allow_int, reduce_axes=reduce_axes) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - # Users should return the auxiliary data like:: - # >>> # 1. example of return one data - # >>> return scalar_loss, data - # >>> # 2. example of return multiple data - # >>> return scalar_loss, (data1, data2, ...) 
- outputs = func(*args, **kwargs) - # outputs: [0] is the value for gradient, - # [1] is other values for return - output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - else: - @partial(jax.grad, - argnums=argnums, has_aux=True, holomorphic=holomorphic, - allow_int=allow_int, reduce_axes=reduce_axes) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - # Users should return the scalar value like this:: - # >>> return scalar_loss - output = func(*args, **kwargs) - output2 = output.value if isinstance(output, Array) else output - return output2, (output, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - return grad_func - - def grad( - func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, - allow_int=False, reduce_axes=(), has_aux=None, return_value=False, - auto_infer=True -) -> ObjectTransform: + func: Callable, + grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + holomorphic: Optional[bool] = False, + allow_int: Optional[bool] = False, + reduce_axes: Optional[Sequence[str]] = (), + has_aux: Optional[bool] = None, + return_value: Optional[bool] = False, +) -> GradientTransform: """Automatic gradient computation for functions or class objects. This gradient function only support scalar return. It creates a function @@ -369,6 +271,9 @@ def grad( (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) dyn_vars : optional, ArrayType, sequence of ArrayType, dict The dynamically changed variables used in ``func``. + child_objs: optional, BrainPyObject, sequnce, dict + + .. versionadded:: 2.3.1 grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. argnums : optional, integer or sequence of integers @@ -394,13 +299,10 @@ def grad( is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a function that computes the total gradient while ``grad(f)`` will create one that computes the per-example gradient. - auto_infer: bool - Automatically infer all ``Variable`` instances used in the target. - Returns ------- - func : ObjectTransform + func : GradientTransform A function with the same arguments as ``fun``, that evaluates the gradient of ``fun``. If ``argnums`` is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If @@ -408,74 +310,20 @@ def grad( same shapes and types as the corresponding arguments. If ``has_aux`` is True then a pair of (gradient, auxiliary_data) is returned. 
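  Examples
  --------
  A minimal usage sketch, added for illustration (the toy losses, ``w`` and the
  input arrays are placeholders, not part of the original code):

  >>> import brainpy.math as bm
  >>> # functional style: differentiate w.r.t. positional arguments
  >>> f_grad = bm.grad(lambda a, b: (a * b).sum(), argnums=(0, 1))
  >>> ga, gb = f_grad(bm.ones(3).value, bm.ones(3).value)
  >>> # object-oriented style: differentiate w.r.t. a trainable variable
  >>> w = bm.Variable(bm.ones(3))
  >>> f_grad = bm.grad(lambda x: (w * x).sum(), grad_vars=w)
  >>> gw = f_grad(bm.ones(3).value)  # the gradient w.r.t. ``w``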
""" - - if dyn_vars is None: - dyn_vars = infer_dyn_vars(func) if auto_infer else dict() - - dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) - # dyn_vars -> ArrayCollector - # grad_vars -> ArrayCollector - has_aux = False if has_aux is None else has_aux - - # gradient - if len(dyn_vars) == 0 and len(grad_vars) == 0: - argnums = 0 if argnums is None else argnums - if return_value: - grad_func = jax.value_and_grad(fun=func, - argnums=argnums, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes) - - def call_func(*args, **kwargs): - result = grad_func(*args, **kwargs) - if has_aux: - (ans, aux), g = result - return g, ans, aux - else: - ans, g = result - return g, ans - - return GradientFunTransform(call_func, dyn_vars=dyn_vars, grad_vars=grad_vars, origin_fun=func) - - else: - # has_aux = True: g, aux - # has_aux = False: g - call_func = jax.grad(fun=func, + child_objs = check.is_all_objs(child_objs, out_as='dict') + dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + return GradientTransform(target=func, + transform=jax.grad, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + child_objs=child_objs, argnums=argnums, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes) - return GradientFunTransform(call_func, dyn_vars=dyn_vars, grad_vars=grad_vars, origin_fun=func) - - else: - # argnums - _argnums, _ = tree_flatten(argnums) - _argnums = tuple(a + 2 for a in _argnums) - if argnums is None and len(grad_vars) == 0: - raise errors.MathError('We detect no require to compute gradients because ' - '"grad_vars" is None and "argnums" is also None. ' - 'Please provide one of them.') - # computation - grad_func = _cls_grad(func=func, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=(0,) + _argnums, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes) - - return GradientTransform(grad_func=grad_func, - grad_tree=grad_tree, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=argnums, - return_value=return_value, - has_aux=has_aux, - origin_fun=func) + return_value=return_value, + has_aux=False if has_aux is None else has_aux, + transform_setting=dict(holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes)) def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None): @@ -501,6 +349,7 @@ def _std_basis(pytree): def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False): _check_callable(fun) + @wraps(fun) def jacfun(*args, **kwargs): f = linear_util.wrap_init(fun, kwargs) f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) @@ -520,46 +369,19 @@ def jacfun(*args, **kwargs): else: return (jac, aux) if has_aux else jac - return GradientFunTransform(jacfun, dyn_vars=(), grad_vars=(), origin_fun=fun) - - -def _cls_jacrev(func, grad_vars, dyn_vars, argnums, - holomorphic=False, allow_int=False, has_aux=False): - # parameters - assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(argnums, (tuple, list)) # tuple/list of int - - # final functions - if has_aux: - @partial(_jacrev, argnums=argnums, holomorphic=holomorphic, - allow_int=allow_int, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - # outputs: [0] is the 
value for gradient, - # [1] is other values for return - outputs = func(*args, **kwargs) - output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - else: - @partial(_jacrev, argnums=argnums, holomorphic=holomorphic, - allow_int=allow_int, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - outputs = func(*args, **kwargs) - output = outputs.value if isinstance(outputs, Array) else outputs - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - return grad_func + return jacfun def jacrev( - func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, - allow_int=False, has_aux=None, return_value=False, - auto_infer=True + func: Callable, + grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + has_aux: Optional[bool] = None, + return_value: bool = False, + holomorphic: bool = False, + allow_int: bool = False, ) -> ObjectTransform: """Extending automatic Jacobian (reverse-mode) of ``func`` to classes. @@ -594,65 +416,44 @@ def jacrev( The dynamically changed variables used in ``func``. grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. + child_objs: optional, BrainPyObject, sequence, dict + + .. versionadded:: 2.3.1 has_aux: optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. return_value : bool Whether return the loss value. - argnums: Optional, integer or sequence of integers. Specifies which + argnums: Optional, integer or sequence of integers. + Specifies which positional argument(s) to differentiate with respect to (default ``0``). - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic: Optional, bool. + Indicates whether ``fun`` is promised to be holomorphic. Default False. - allow_int: Optional, bool. Whether to allow differentiating with + allow_int: Optional, bool. + Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. - auto_infer: bool - Automatically infer all ``Variable`` instance. Returns ------- - fun: ObjectTransform + fun: GradientTransform The transformed object. 
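  Examples
  --------
  A short sketch for illustration (the cubic function and its input are toy
  placeholders):

  >>> import brainpy.math as bm
  >>> jac = bm.jacrev(lambda x: x ** 3)
  >>> jac(bm.arange(3.).value)  # a diagonal 3x3 Jacobian

  ``grad_vars``, ``dyn_vars`` and ``child_objs`` are handled in the same way
  as in ``grad`` above.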
""" - if dyn_vars is None: - dyn_vars = infer_dyn_vars(func) if auto_infer else dict() - - dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) - has_aux = False if has_aux is None else has_aux - - if (len(dyn_vars) == 0) and (len(grad_vars) == 0): - argnums = 0 if argnums is None else argnums - return _jacrev(fun=func, - argnums=argnums, - holomorphic=holomorphic, - allow_int=allow_int, - has_aux=has_aux, - return_value=return_value) - else: - _argnums, _ = tree_flatten(argnums) - _argnums = tuple(a + 2 for a in _argnums) - if argnums is None and len(grad_vars) == 0: - raise errors.MathError('We detect no require to compute gradients because ' - '"grad_vars" is None and "argnums" is also None. ' - 'Please provide one of them.') - # computation - grad_func = _cls_jacrev(func=func, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=(0,) + _argnums, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int) - - return GradientTransform(grad_func=grad_func, - grad_tree=grad_tree, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=argnums, - return_value=return_value, - has_aux=has_aux, - origin_fun=func) + child_objs = check.is_all_objs(child_objs, out_as='dict') + dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + return GradientTransform(target=func, + transform=_jacrev, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + child_objs=child_objs, + argnums=argnums, + return_value=return_value, + has_aux=False if has_aux is None else has_aux, + transform_setting=dict(holomorphic=holomorphic, + allow_int=allow_int)) jacobian = jacrev @@ -664,6 +465,7 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False raise NotImplementedError(f'"has_aux" only supported in jax>=0.2.28, but we detect ' f'the current jax version is {jax.__version__}') + @wraps(fun) def jacfun(*args, **kwargs): f = linear_util.wrap_init(fun, kwargs) f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) @@ -682,44 +484,18 @@ def jacfun(*args, **kwargs): else: return (jac, aux) if has_aux else jac - return GradientFunTransform(jacfun, dyn_vars=(), grad_vars=(), origin_fun=fun) - - -def _cls_jacfwd(func, grad_vars, dyn_vars, argnums, holomorphic=False, has_aux=False): - # parameters - assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(argnums, (tuple, list)) # tuple/list of int - - # final functions - if has_aux: - @partial(_jacfwd, - argnums=argnums, holomorphic=holomorphic, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - # outputs: [0] is the value for gradient, - # [1] is other values for return - outputs = func(*args, **kwargs) - output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - else: - @partial(_jacfwd, - argnums=argnums, holomorphic=holomorphic, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - outputs = func(*args, **kwargs) - output = outputs.value if isinstance(outputs, Array) else outputs - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - return grad_func + return jacfun def jacfwd( - func, 
grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, - has_aux=None, return_value=False, auto_infer=True + func: Callable, + grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + has_aux: Optional[bool] = None, + return_value: bool = False, + holomorphic: bool = False, ) -> ObjectTransform: """Extending automatic Jacobian (forward-mode) of ``func`` to classes. @@ -753,6 +529,9 @@ def jacfwd( The dynamically changed variables used in ``func``. grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + + .. versionadded:: 2.3.1 has_aux: optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be @@ -763,55 +542,34 @@ def jacfwd( positional argument(s) to differentiate with respect to (default ``0``). holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be holomorphic. Default False. - auto_infer: bool - Automatically infer all ``Variable`` instance. Returns ------- - obj: ObjectTransform + obj: GradientTransform The transformed object. """ - if dyn_vars is None: - dyn_vars = infer_dyn_vars(func) if auto_infer else dict() - - dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) - has_aux = False if has_aux is None else has_aux - - if (len(dyn_vars) == 0) and (len(grad_vars) == 0): - argnums = 0 if argnums is None else argnums - return _jacfwd(fun=func, - argnums=argnums, - holomorphic=holomorphic, - has_aux=has_aux, - return_value=return_value) - else: - _argnums, _ = tree_flatten(argnums) - _argnums = tuple(a + 2 for a in _argnums) - if argnums is None and len(grad_vars) == 0: - raise errors.MathError('We detect no require to compute gradients because ' - '"grad_vars" is None and "argnums" is also None. 
' - 'Please provide one of them.') - # computation - grad_func = _cls_jacfwd(func=func, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=(0,) + _argnums, - has_aux=has_aux, - holomorphic=holomorphic) - - return GradientTransform(grad_func=grad_func, - grad_tree=grad_tree, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=argnums, - return_value=return_value, - has_aux=has_aux, - origin_fun=func) + child_objs = check.is_all_objs(child_objs, out_as='dict') + dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + return GradientTransform(target=func, + transform=_jacfwd, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + child_objs=child_objs, + argnums=argnums, + return_value=return_value, + has_aux=False if has_aux is None else has_aux, + transform_setting=dict(holomorphic=holomorphic)) def hessian( - func, grad_vars=None, dyn_vars=None, argnums=None, - holomorphic=False, return_value=False, auto_infer=True + func: Callable, + grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + return_value: bool = False, + holomorphic=False, ) -> ObjectTransform: """Hessian of ``func`` as a dense array. @@ -824,6 +582,9 @@ def hessian( containers thereof. dyn_vars : optional, ArrayCollector, sequence of ArrayType The dynamical changed variables. + child_objs: optional, BrainPyObject, sequnce, dict + + .. versionadded:: 2.3.1 grad_vars : optional, ArrayCollector, sequence of ArrayType The variables required to compute their gradients. argnums: Optional, integer or sequence of integers @@ -832,39 +593,33 @@ def hessian( Indicates whether ``fun`` is promised to be holomorphic. Default False. return_value : bool Whether return the hessian values. - auto_infer: bool - Automatically infer all ``Variable`` instance. Returns ------- obj: ObjectTransform The transformed object. 
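  Examples
  --------
  A rough sketch for illustration (the quadratic function is a toy placeholder):

  >>> import brainpy.math as bm
  >>> f_hess = bm.hessian(lambda x: (x ** 2).sum())
  >>> f_hess(bm.ones(3).value)  # the 3x3 Hessian, here 2 * identity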
""" - if dyn_vars is None: - dyn_vars = infer_dyn_vars(func) if auto_infer else dict() - - dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) - argnums = 0 if argnums is None else argnums - - if (len(dyn_vars) == 0) and (len(grad_vars) == 0) and (not return_value): - f = jax.hessian(func, argnums=argnums, holomorphic=holomorphic) - return GradientFunTransform(f, dyn_vars=(), grad_vars=(), origin_fun=func) - else: - return jacfwd(jacrev(func, - dyn_vars=dyn_vars, - grad_vars=grad_vars, - argnums=argnums, - holomorphic=holomorphic), - dyn_vars=dyn_vars, - grad_vars=grad_vars, - argnums=argnums, - holomorphic=holomorphic, - return_value=return_value) + child_objs = check.is_all_objs(child_objs, out_as='dict') + dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + return jacfwd(jacrev(func, + dyn_vars=dyn_vars, + child_objs=child_objs, + grad_vars=grad_vars, + argnums=argnums, + holomorphic=holomorphic), + dyn_vars=dyn_vars, + child_objs=child_objs, + grad_vars=grad_vars, + argnums=argnums, + holomorphic=holomorphic, + return_value=return_value) def _vector_grad(func, argnums=0, return_value=False, has_aux=False): _check_callable(func) + @wraps(func) def grad_fun(*args, **kwargs): f = linear_util.wrap_init(func, kwargs) f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) @@ -882,40 +637,17 @@ def grad_fun(*args, **kwargs): else: return (grads, y) if return_value else grads - return GradientFunTransform(grad_fun, (), (), origin_fun=func) - - -def _cls_vector_grad(func, grad_vars, dyn_vars, argnums, has_aux=False): - # parameters - assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable - assert isinstance(argnums, (tuple, list)) # tuple/list of int - - # final functions - if has_aux: - @partial(_vector_grad, argnums=argnums, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - outputs = func(*args, **kwargs) - output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - else: - @partial(_vector_grad, argnums=argnums, has_aux=True) - def grad_func(grad_values, dyn_values, *args, **kwargs): - for v, d in zip(dyn_vars, dyn_values): v._value = d - for v, d in zip(grad_vars, grad_values): v._value = d - outputs = func(*args, **kwargs) - output = outputs.value if isinstance(outputs, Array) else outputs - return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) - - return grad_func + return grad_fun def vector_grad( - func, dyn_vars=None, grad_vars=None, argnums=None, - return_value=False, has_aux=None, auto_infer=True + func: Callable, + grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, + child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + return_value: bool = False, + has_aux: Optional[bool] = None, ) -> ObjectTransform: """Take vector-valued gradients for function ``func``. @@ -943,9 +675,13 @@ def vector_grad( Parameters ---------- - func: Function whose Jacobian is to be computed. + func: Callable + Function whose gradient is to be computed. 
dyn_vars : optional, ArrayType, sequence of ArrayType, dict The dynamically changed variables used in ``func``. + child_objs: optional, BrainPyObject, sequnce, dict + + .. versionadded:: 2.3.1 grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. has_aux: optional, bool @@ -956,46 +692,20 @@ def vector_grad( Whether return the loss value. argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default ``0``). - auto_infer: bool - Automatically infer all ``Variable`` instance. Returns ------- - func : ObjectTransform + func : GradientTransform The vector gradient function. """ - if dyn_vars is None: - dyn_vars = infer_dyn_vars(func) if auto_infer else dict() - - dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) - has_aux = False if has_aux is None else has_aux - - if (len(dyn_vars) == 0) and (len(grad_vars) == 0): - argnums = 0 if argnums is None else argnums - return _vector_grad(func=func, - argnums=argnums, - return_value=return_value, - has_aux=has_aux) - - else: - _argnums, _ = tree_flatten(argnums) - _argnums = tuple(a + 2 for a in _argnums) - if argnums is None and len(grad_vars) == 0: - raise errors.MathError('We detect no require to compute gradients because ' - '"grad_vars" is None and "argnums" is also None. ' - 'Please provide one of them.') - # computation - grad_func = _cls_vector_grad(func=func, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=(0,) + _argnums, - has_aux=has_aux) - - return GradientTransform(grad_func=grad_func, - grad_tree=grad_tree, - grad_vars=grad_vars, - dyn_vars=dyn_vars, - argnums=argnums, - return_value=return_value, - has_aux=has_aux, - origin_fun=func) + child_objs = check.is_all_objs(child_objs, out_as='dict') + dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + return GradientTransform(target=func, + transform=_vector_grad, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + child_objs=child_objs, + argnums=argnums, + return_value=return_value, + has_aux=False if has_aux is None else has_aux) diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index b32a1ca30..8578efbbe 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -36,18 +36,20 @@ class ControlObject(ObjectTransform): + """Object-oriented Control Flow Transformation in BrainPy. + """ def __init__( self, call: Callable, - dyn_vars, + dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]], repr_fun: Dict, name=None ): super().__init__(name=name) + self.register_implicit_vars(dyn_vars) self._f = call self._dyn_vars = dyn_vars - self.register_implicit_vars(dyn_vars) self._repr_fun = repr_fun def __call__(self, *args, **kwargs): @@ -117,10 +119,10 @@ def fun2scan(dyn_values, x): def make_loop( - body_fun, - dyn_vars, - out_vars=None, - has_return=False + body_fun: Callable, + dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]], + out_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]]=None, + has_return: bool =False ) -> ControlObject: """Make a for-loop function, which iterate over inputs. 
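A usage sketch of ``make_loop``, for illustration only (the variable ``a``, the
body function and the inputs below are toy placeholders, not part of this patch):

    >>> import brainpy.math as bm
    >>> a = bm.Variable(bm.zeros(1))
    >>> def body(x):
    ...     a.value += x
    >>> loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
    >>> hist_of_a = loop(bm.arange(5.).value)  # the history of ``a`` over the loop

``dyn_vars`` should cover every ``Variable`` used inside ``body``; ``out_vars``
selects which variables are recorded at each step.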
@@ -204,7 +206,8 @@ def call(xs=None, length=None): try: add_context(name) dyn_values, (out_values, results) = lax.scan( - f=fun2scan, init=init_values, xs=xs, length=length) + f=fun2scan, init=init_values, xs=xs, length=length + ) del_context(name) except UnexpectedTracerError as e: del_context(name) @@ -318,7 +321,10 @@ def call(x=None): raise e for v, d in zip(dyn_vars, dyn_values): v._value = d - return ControlObject(call, dyn_vars, repr_fun={'cond_fun': cond_fun, 'body_fun': body_fun}) + return ControlObject(call=call, + dyn_vars=dyn_vars, + repr_fun={'cond_fun': cond_fun, 'body_fun': body_fun}, + name=name) def make_cond( From c314df881c9a7821005605906cd86e0133da9ad5 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 26 Dec 2022 21:30:14 +0800 Subject: [PATCH 2/6] fix tests --- .../tests/test_compat.py} | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename brainpy/math/{operators/tests/test_differential_spike.py => surrogate/tests/test_compat.py} (89%) diff --git a/brainpy/math/operators/tests/test_differential_spike.py b/brainpy/math/surrogate/tests/test_compat.py similarity index 89% rename from brainpy/math/operators/tests/test_differential_spike.py rename to brainpy/math/surrogate/tests/test_compat.py index 3f071222e..17d8a33a6 100644 --- a/brainpy/math/operators/tests/test_differential_spike.py +++ b/brainpy/math/surrogate/tests/test_compat.py @@ -30,16 +30,16 @@ def f5(a, b): self.f5 = f5 def test_sp_sigmoid_grad2(self): - a = bm.ones(10) * 2 - b = bm.ones(10) + a = bm.ones(10).value * 2 + b = bm.ones(10).value grad1, val1 = self.f4(a, b) grad2, val2 = self.f5(a, b) self.assertTrue(bm.array_equal(grad1, grad2)) self.assertTrue(bm.array_equal(val1, val2)) def test_sp_sigmoid_grad1(self): - a = bm.zeros(10) - b = bm.ones(10) + a = bm.zeros(10).value + b = bm.ones(10).value grad1, val1 = self.f4(a, b) grad2, val2 = self.f5(a, b) print(grad2) @@ -49,8 +49,8 @@ def test_sp_sigmoid_grad1(self): self.assertTrue(~bm.array_equal(val1, val2)) def test_sp_sigmoid_grad3(self): - a = bm.ones(10) * -2 - b = bm.ones(10) + a = bm.ones(10).value * -2 + b = bm.ones(10).value grad1, val1 = self.f4(a, b) grad2, val2 = self.f5(a, b) self.assertTrue(bm.array_equal(grad1, grad2)) From 0690fac648e6cb5ff67bbf6138eba62dcccc0d14 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 27 Dec 2022 14:54:42 +0800 Subject: [PATCH 3/6] support calling a function after each report in `BPTrainer` --- brainpy/train/back_propagation.py | 81 ++++++++++++++----------------- brainpy/train/base.py | 8 ++- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index 2c159484a..f3e334596 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -2,13 +2,12 @@ import sys import time -import warnings from collections.abc import Iterable from functools import partial -from typing import Union, Dict, Callable, Sequence, Any +from typing import Union, Dict, Callable, Sequence, Any, Optional -import numpy as np import jax.numpy as jnp +import numpy as np from jax.tree_util import tree_map import brainpy.losses as losses @@ -55,7 +54,6 @@ class BPTrainer(DSTrainer): Make the monitored results as NumPy arrays. logger: Any A file-like object (stream); defaults to the current `sys.stdout`. - shuffle_data: bool .. deprecated:: 2.2.4.1 Control the data shuffling by user self. 
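The following changes to ``BPTrainer.fit`` add the ``fun_after_report`` hook.
A minimal sketch of how it can be used (the trainer construction and the data
are omitted; ``trainer``, ``train_data`` and ``report_hook`` are placeholders):

    >>> def report_hook(idx, metrics, phase):
    ...     # ``phase`` is 'fit' or 'test'; ``metrics`` holds the metrics defined in the loss function
    ...     print(phase, idx, metrics)
    >>> trainer.fit(train_data, num_epoch=10, fun_after_report=report_hook)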
@@ -171,51 +169,21 @@ def train_losses(self): def test_losses(self): return self.get_hist_metric(phase='test') - def predict( - self, - inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], - reset_state: bool = True, - shared_args: Dict = None, - eval_time: bool = False - ) -> Output: - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output, if users want. - - Parameters - ---------- - inputs: ArrayType, sequence, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - shared_args: dict - Shared keyword arguments for the given target model. - reset_state: bool - Whether reset the model states. Default True. - eval_time: bool - Whether evaluate the running time or not. Default False. - """ - return super().predict(inputs=inputs, - reset_state=reset_state, - shared_args=shared_args, - eval_time=eval_time) - def fit( self, train_data: Union[Callable, Iterable], - test_data: Union[Callable, Iterable] = None, + test_data: Optional[Union[Callable, Iterable]] = None, num_epoch: int = 100, num_report: int = -1, reset_state: bool = True, - shared_args: Dict = None, + shared_args: Optional[Dict] = None, + fun_after_report: Optional[Callable] = None, # ------ # API deprecated batch_size: int = None, ): - """ - Fit the target model according to the given training data. + """Fit the target model according to the given training data. Parameters ---------- @@ -233,23 +201,27 @@ def fit( then we will only fit the model with the only last output. - If the shape of each tensor is `(num_sample, num_time, num_feature)`, then the fitting happens on the whole data series. - test_data: callable, iterable, optional Same as ``train_data``. - num_epoch: int The number of training epoch. Default 100. - num_report: int The number of step to report the progress. If `num_report=-1`, it will report the training progress each epoch. - reset_state: bool Whether reset the initial states of the target model. - shared_args: dict The shared keyword arguments for the target models. - + fun_after_report: optional, Callable + The function to call after each report of `fit` phase or `test` phase. + The function should receive three arguments: + - ``idx`` for the indicator the current the running index. (If ``report=-1``, + The running index is the epoch. Otherwise, is the 'fit_idx' for 'fit' phase + and 'test_idx' for 'test' phase). + - ``metrics``: the metrics defined in the loss function + - ``phase``: to indicate the phase of 'fit' or 'test'. + + .. versionadded:: 2.3.1 batch_size: int .. 
deprecated:: 2.2.4.1 @@ -264,6 +236,16 @@ def fit( if len(train_data) == 2: raise UnsupportedError(msg) + if fun_after_report is not None: + assert callable(fun_after_report), ('\n' + 'Unknown "fun_after_report", ' + 'it should be a callable function receiving ' + 'three arguments: idx, metrics, phase') + + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + true_progress_bar = self.progress_bar self.progress_bar = False @@ -316,6 +298,8 @@ def fit( print((f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))), file=self.logger) + if fun_after_report is not None: + fun_after_report(fit_i, aux, 'fit') fit_t0 = time.time() fit_t = 0 @@ -333,6 +317,8 @@ def fit( print((f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))), file=self.logger) + if fun_after_report is not None: + fun_after_report(epoch_idx, aux, 'fit') else: fit_t = time.time() - fit_t0 @@ -377,6 +363,8 @@ def fit( print((f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))), file=self.logger) + if fun_after_report is not None: + fun_after_report(test_i, aux, 'test') test_t0 = time.time() test_t = 0 @@ -394,6 +382,8 @@ def fit( print((f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))), file=self.logger) + if fun_after_report is not None: + fun_after_report(epoch_idx, aux, 'test') else: test_t = time.time() - test_t0 @@ -609,6 +599,9 @@ def predict( output: ArrayType, dict The model output. """ + if shared_args is None: shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + # reset the model states if reset_state: self.target.reset_state(self._get_input_batch_size(xs=inputs)) diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 55c6392f2..4985be181 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Dict, Sequence, Any, Union +from typing import Dict, Sequence, Any, Union, Optional import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem @@ -67,7 +67,7 @@ def predict( self, inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], reset_state: bool = False, - shared_args: Dict = None, + shared_args: Optional[Dict] = None, eval_time: bool = False ) -> Output: """Prediction function. @@ -88,6 +88,10 @@ def predict( output: ArrayType, sequence of ArrayType, dict of ArrayType The running output. 
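    Examples
    --------
    A minimal sketch (``trainer`` and ``data`` are placeholders for a concrete
    trainer instance and its input data):

    >>> out = trainer.predict(data, reset_state=True)
    >>> res = trainer.predict(data, eval_time=True)  # also measures the running time
    >>> out = trainer.predict(data, shared_args={'fit': False})  # the flag this patch now sets to False by default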
""" + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + return super().predict(inputs=inputs, reset_state=reset_state, shared_args=shared_args, From 0c5342833e12f75e5c1f1f5f3cd1372fa6c540ce Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 27 Dec 2022 14:56:00 +0800 Subject: [PATCH 4/6] updates --- brainpy/base/function.py | 9 +- brainpy/checkpoints.py | 2 +- brainpy/math/fft.py | 82 +- brainpy/math/linalg.py | 100 ++- brainpy/math/numpy_ops.py | 758 +++++++++--------- brainpy/math/object_transform/controls.py | 3 +- brainpy/math/random.py | 12 +- .../{tests/test_checking.py => test_check.py} | 0 brainpy/test_checkpoints.py | 1 + brainpy/types.py | 5 +- 10 files changed, 490 insertions(+), 482 deletions(-) rename brainpy/{tests/test_checking.py => test_check.py} (100%) create mode 100644 brainpy/test_checkpoints.py diff --git a/brainpy/base/function.py b/brainpy/base/function.py index bac11da15..0eedafc76 100644 --- a/brainpy/base/function.py +++ b/brainpy/base/function.py @@ -1,9 +1,12 @@ # -*- coding: utf-8 -*- -from typing import Callable, Sequence, Dict, Union +from typing import Callable, Sequence, Dict, Union, TypeVar from brainpy.base.base import BrainPyObject -from brainpy.types import ArrayType + + +Variable = TypeVar('Variable') + __all__ = [ 'FunAsObject', @@ -28,7 +31,7 @@ class FunAsObject(BrainPyObject): def __init__(self, f: Callable, child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[dict, BrainPyObject]] = None, - dyn_vars: Union[ArrayType, Sequence[ArrayType], Dict[dict, ArrayType]] = None, + dyn_vars: Union[Variable, Sequence[Variable], Dict[dict, Variable]] = None, name: str = None): super(FunAsObject, self).__init__(name=name) self._f = f diff --git a/brainpy/checkpoints.py b/brainpy/checkpoints.py index 8e9a2e074..47a474566 100644 --- a/brainpy/checkpoints.py +++ b/brainpy/checkpoints.py @@ -1304,7 +1304,7 @@ def load( gda_manager: Optional[Any] = None, allow_partial_mpa_restoration: bool = False, ) -> PyTree: - """Load last or best checkpoint from the given checkpoint path. + """Load last or best checkpoint from the given checkpoint path. 
Sorts the checkpoint files naturally, returning the highest-valued file, e.g.: diff --git a/brainpy/math/fft.py b/brainpy/math/fft.py index 426a33345..106ba8efe 100644 --- a/brainpy/math/fft.py +++ b/brainpy/math/fft.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- +from typing import Optional import jax.numpy.fft -from brainpy.math.ndarray import Array -from brainpy.math.numpy_ops import _remove_brainpy_array +from brainpy.math.numpy_ops import _as_jax_array_ __all__ = [ "fft", "fft2", "fftfreq", "fftn", "fftshift", "hfft", @@ -12,89 +12,95 @@ ] -def fft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm)) +def fft(a, + n: Optional[int] = None, + axis: int = -1, + norm: Optional[str] = None): + a = _as_jax_array_(a) + return jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm) def fft2(a, s=None, axes=(-2, -1), norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm) def fftfreq(n, d=1.0): - return Array(jax.numpy.fft.fftfreq(n=n, d=d)) + return jax.numpy.fft.fftfreq(n=n, d=d) def fftn(a, s=None, axes=None, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm) def fftshift(x, axes=None): - x = _remove_brainpy_array(x) - return Array(jax.numpy.fft.fftshift(x=x, axes=axes)) + x = _as_jax_array_(x) + return jax.numpy.fft.fftshift(x=x, axes=axes) def hfft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm) -def ifft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm)) +def ifft(a, + n: Optional[int] = None, + axis: int = -1, + norm: Optional[str] = None): + a = _as_jax_array_(a) + return jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm) def ifft2(a, s=None, axes=(-2, -1), norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm) def ifftn(a, s=None, axes=None, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm) def ifftshift(x, axes=None): - x = _remove_brainpy_array(x) - return Array(jax.numpy.fft.ifftshift(x=x, axes=axes)) + x = _as_jax_array_(x) + return jax.numpy.fft.ifftshift(x=x, axes=axes) def ihfft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm) def irfft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm) def irfft2(a, s=None, axes=(-2, -1), norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm) def irfftn(a, s=None, axes=None, norm=None): - a = 
_remove_brainpy_array(a) - return Array(jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm) def rfft(a, n=None, axis=-1, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm) def rfft2(a, s=None, axes=(-2, -1), norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm) def rfftfreq(n, d=1.0): - return Array(jax.numpy.fft.rfftfreq(n=n, d=d)) + return jax.numpy.fft.rfftfreq(n=n, d=d) def rfftn(a, s=None, axes=None, norm=None): - a = _remove_brainpy_array(a) - return Array(jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm)) + a = _as_jax_array_(a) + return jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm) diff --git a/brainpy/math/linalg.py b/brainpy/math/linalg.py index f2b25ddc9..ac96329e8 100644 --- a/brainpy/math/linalg.py +++ b/brainpy/math/linalg.py @@ -3,7 +3,7 @@ from jax.numpy import linalg from brainpy.math.ndarray import Array -from brainpy.math.numpy_ops import _remove_brainpy_array +from brainpy.math.numpy_ops import _as_jax_array_ __all__ = [ 'cholesky', 'cond', 'det', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'inv', 'svd', @@ -13,115 +13,107 @@ def cholesky(a): - a = _remove_brainpy_array(a) - return Array(linalg.cholesky(a)) + a = _as_jax_array_(a) + return linalg.cholesky(a) def cond(x, p=None): - x = _remove_brainpy_array(x) - p = _remove_brainpy_array(p) + x = _as_jax_array_(x) + p = _as_jax_array_(p) return linalg.cond(x, p=p) def det(a): - a = _remove_brainpy_array(a) - return Array(linalg.det(a)) + a = _as_jax_array_(a) + return linalg.det(a) def eig(a): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) w, v = linalg.eig(a) - return Array(w), Array(v) + return w, v def eigh(a, UPLO='L'): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) w, v = linalg.eigh(a, UPLO=UPLO) - return Array(w), Array(v) + return w, v def eigvals(a): - a = _remove_brainpy_array(a) - return Array(linalg.eigvals(a)) + a = _as_jax_array_(a) + return linalg.eigvals(a) def eigvalsh(a, UPLO='L'): - a = _remove_brainpy_array(a) - return Array(linalg.eigvalsh(a, UPLO=UPLO)) + a = _as_jax_array_(a) + return linalg.eigvalsh(a, UPLO=UPLO) def inv(a): - a = _remove_brainpy_array(a) - return Array(linalg.inv(a)) + a = _as_jax_array_(a) + return linalg.inv(a) def svd(a, full_matrices=True, compute_uv=True): - a = _remove_brainpy_array(a) - u, s, vh = linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) - return Array(u), Array(s), Array(vh) + a = _as_jax_array_(a) + return linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) def lstsq(a, b, rcond=None): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) - x, resid, rank, s = linalg.lstsq(a, b, rcond=rcond) - return Array(x), Array(resid), rank, Array(s) + a = _as_jax_array_(a) + b = _as_jax_array_(b) + return linalg.lstsq(a, b, rcond=rcond) def matrix_power(a, n): - a = _remove_brainpy_array(a) - return Array(linalg.matrix_power(a, n)) + a = _as_jax_array_(a) + return linalg.matrix_power(a, n) def matrix_rank(M, tol=None): - M = _remove_brainpy_array(M) - r = linalg.matrix_rank(M, tol=tol) - return r if isinstance(r, int) else Array(r) + M = _as_jax_array_(M) + return linalg.matrix_rank(M, tol=tol) def norm(x, ord=None, axis=None, keepdims=False): - x = 
_remove_brainpy_array(x) - r = linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - return r if axis is None else Array(r) + x = _as_jax_array_(x) + return linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) def pinv(a, rcond=None): - a = _remove_brainpy_array(a) - rcond = _remove_brainpy_array(rcond) - return Array(linalg.pinv(a, rcond=rcond)) + a = _as_jax_array_(a) + rcond = _as_jax_array_(rcond) + return linalg.pinv(a, rcond=rcond) def qr(a, mode="reduced"): - a = _remove_brainpy_array(a) - r = linalg.qr(a, mode=mode) - if isinstance(r, (tuple, list)): - return Array(r[0]), Array(r[1]) - else: - return Array(r) + a = _as_jax_array_(a) + return linalg.qr(a, mode=mode) def solve(a, b): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) - return Array(linalg.solve(a, b)) + a = _as_jax_array_(a) + b = _as_jax_array_(b) + return linalg.solve(a, b) def slogdet(a): - a = _remove_brainpy_array(a) - return Array(linalg.slogdet(a)) + a = _as_jax_array_(a) + return linalg.slogdet(a) def tensorinv(a, ind=2): - a = _remove_brainpy_array(a) - return Array(linalg.tensorinv(a, ind=ind)) + a = _as_jax_array_(a) + return linalg.tensorinv(a, ind=ind) def tensorsolve(a, b, axes=None): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) - return Array(linalg.tensorsolve(a, b, axes=axes)) + a = _as_jax_array_(a) + b = _as_jax_array_(b) + return linalg.tensorsolve(a, b, axes=axes) -def multi_dot(arrays): - arrays = [_remove_brainpy_array(a) for a in arrays] - return Array(linalg.multi_dot(arrays)) +def multi_dot(arrays, *, precision=None): + arrays = [_as_jax_array_(a) for a in arrays] + return linalg.multi_dot(arrays, precision=precision) diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py index 3aebfb90c..adb0a3290 100644 --- a/brainpy/math/numpy_ops.py +++ b/brainpy/math/numpy_ops.py @@ -209,7 +209,7 @@ def as_variable(tensor, dtype=None): return Variable(asarray(tensor, dtype=dtype)) -def _remove_brainpy_array(obj): +def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj @@ -220,29 +220,29 @@ def clip_by_norm(t, clip_norm, axis=None): @wraps(jnp.delete) def delete(arr, obj, axis=None): - arr = _remove_brainpy_array(arr) - obj = _remove_brainpy_array(obj) + arr = _as_jax_array_(arr) + obj = _as_jax_array_(obj) return Array(jnp.delete(arr, obj, axis=axis)) @wraps(jnp.take_along_axis) def take_along_axis(a, indices, axis, mode=None): - a = _remove_brainpy_array(a) - indices = _remove_brainpy_array(indices) + a = _as_jax_array_(a) + indices = _as_jax_array_(indices) return Array(jnp.take_along_axis(a, indices, axis, mode)) @wraps(jnp.block) def block(arrays): leaves, tree = tree_flatten(arrays, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_remove_brainpy_array(l) for l in leaves] + leaves = [_as_jax_array_(l) for l in leaves] arrays = tree_unflatten(tree, leaves) return Array(jnp.block(arrays)) @wraps(jnp.broadcast_arrays) def broadcast_arrays(*args): - args = [(_remove_brainpy_array(a)) for a in args] + args = [(_as_jax_array_(a)) for a in args] return jnp.broadcast_arrays(args) @@ -251,14 +251,14 @@ def broadcast_arrays(*args): @wraps(jnp.broadcast_to) def broadcast_to(arr, shape): - arr = _remove_brainpy_array(arr) + arr = _as_jax_array_(arr) return Array(jnp.broadcast_to(arr, shape)) @wraps(jnp.compress) def compress(condition, a, axis=None, out=None): - condition = _remove_brainpy_array(condition) - a = _remove_brainpy_array(a) + condition = _as_jax_array_(condition) + a = _as_jax_array_(a) return Array(jnp.compress(condition, a, 
axis, out)) @@ -273,7 +273,7 @@ def diag_indices(n, ndim=2): @wraps(jnp.diag_indices_from) def diag_indices_from(arr): - arr = _remove_brainpy_array(arr) + arr = _as_jax_array_(arr) res = jnp.diag_indices_from(arr) if isinstance(res, tuple): return tuple(Array(r) for r in res) @@ -283,25 +283,25 @@ def diag_indices_from(arr): @wraps(jnp.diagflat) def diagflat(v, k=0): - v = _remove_brainpy_array(v) + v = _as_jax_array_(v) return Array(jnp.diagflat(v, k)) @wraps(jnp.diagonal) def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.diagonal(a, offset, axis1, axis2)) @wraps(jnp.einsum) def einsum(*operands, out=None, optimize='optimal', precision=None, _use_xeinsum=False): - operands = tuple((_remove_brainpy_array(a)) for a in operands) + operands = tuple((_as_jax_array_(a)) for a in operands) return Array(jnp.einsum(*operands, out=out, optimize=optimize, precision=precision, _use_xeinsum=_use_xeinsum)) @wraps(jnp.einsum_path) def einsum_path(subscripts, *operands, optimize='greedy'): - operands = tuple((_remove_brainpy_array(a)) for a in operands) + operands = tuple((_as_jax_array_(a)) for a in operands) return jnp.einsum_path(subscripts, *operands, optimize=optimize) @@ -312,7 +312,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0): @wraps(jnp.gradient) def gradient(f, *varargs, axis=None, edge_order=None): - f = _remove_brainpy_array(f) + f = _as_jax_array_(f) res = jnp.gradient(f, *varargs, axis=axis, edge_order=edge_order) if isinstance(res, (list, tuple)): return list(Array(r) for r in res) @@ -322,35 +322,35 @@ def gradient(f, *varargs, axis=None, edge_order=None): @wraps(jnp.histogram2d) def histogram2d(x, y, bins=10, range=None, weights=None, density=None): - x = _remove_brainpy_array(x) - y = _remove_brainpy_array(y) + x = _as_jax_array_(x) + y = _as_jax_array_(y) H, xedges, yedges = jnp.histogram2d(x, y, bins, range, weights, density) return Array(H), Array(xedges), Array(yedges) @wraps(jnp.histogram_bin_edges) def histogram_bin_edges(a, bins=10, range=None, weights=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.histogram_bin_edges(a, bins, range, weights)) @wraps(jnp.histogramdd) def histogramdd(sample, bins=10, range=None, weights=None, density=None): - sample = _remove_brainpy_array(sample) + sample = _as_jax_array_(sample) r = jnp.histogramdd(sample, bins, range, weights, density) return Array(r[0]), r[1] @wraps(jnp.i0) def i0(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.i0(x)) @wraps(jnp.in1d) def in1d(ar1, ar2, assume_unique=False, invert=False): - ar1 = _remove_brainpy_array(ar1) - ar2 = _remove_brainpy_array(ar2) + ar1 = _as_jax_array_(ar1) + ar2 = _as_jax_array_(ar2) return Array(jnp.in1d(ar1, ar2, assume_unique, invert)) @@ -366,15 +366,15 @@ def indices(dimensions, dtype=None, sparse=False): @wraps(jnp.insert) def insert(arr, obj, values, axis=None): - arr = _remove_brainpy_array(arr) - values = _remove_brainpy_array(values) + arr = _as_jax_array_(arr) + values = _as_jax_array_(values) return Array(jnp.insert(arr, obj, values, axis)) @wraps(jnp.intersect1d) def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): - ar1 = _remove_brainpy_array(ar1) - ar2 = _remove_brainpy_array(ar2) + ar1 = _as_jax_array_(ar1) + ar2 = _as_jax_array_(ar2) res = jnp.intersect1d(ar1, ar2, assume_unique, return_indices) if return_indices: return tuple([Array(r) for r in res]) @@ -384,27 +384,27 @@ def 
intersect1d(ar1, ar2, assume_unique=False, return_indices=False): @wraps(jnp.iscomplex) def iscomplex(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return jnp.iscomplex(x) @wraps(jnp.isin) def isin(element, test_elements, assume_unique=False, invert=False): - element = _remove_brainpy_array(element) - test_elements = _remove_brainpy_array(test_elements) + element = _as_jax_array_(element) + test_elements = _as_jax_array_(test_elements) return Array(jnp.isin(element, test_elements, assume_unique, invert)) @wraps(jnp.ix_) def ix_(*args): - args = [_remove_brainpy_array(a) for a in args] + args = [_as_jax_array_(a) for a in args] return jnp.ix_(*args) @wraps(jnp.lexsort) def lexsort(keys, axis=-1): leaves, tree = tree_flatten(keys, is_leaf=lambda x: isinstance(x, Array)) - leaves = [_remove_brainpy_array(l) for l in leaves] + leaves = [_as_jax_array_(l) for l in leaves] keys = tree_unflatten(tree, leaves) return Array(jnp.lexsort(keys, axis)) @@ -421,14 +421,14 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, @wraps(np.save) def save(file, arr, allow_pickle=True, fix_imports=True): - arr = _remove_brainpy_array(arr) + arr = _as_jax_array_(arr) np.save(file, arr, allow_pickle, fix_imports) @wraps(np.savez) def savez(file, *args, **kwds): - args = [_remove_brainpy_array(a) for a in args] - kwds = {k: _remove_brainpy_array(v) for k, v in kwds.items()} + args = [_as_jax_array_(a) for a in args] + kwds = {k: _as_jax_array_(v) for k, v in kwds.items()} np.savez(file, *args, **kwds) @@ -436,56 +436,56 @@ def savez(file, *args, **kwds): def msort(a): - return Array(jnp.sort(_remove_brainpy_array(a), axis=0)) + return Array(jnp.sort(_as_jax_array_(a), axis=0)) @wraps(jnp.nan_to_num) def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.nan_to_num(x, copy, nan=nan, posinf=posinf, neginf=neginf)) @wraps(jnp.nanargmax) def nanargmax(a, axis=None, out=None, keepdims=None): - return Array(jnp.nanargmax(_remove_brainpy_array(a), axis=axis, out=out, keepdims=keepdims)) + return Array(jnp.nanargmax(_as_jax_array_(a), axis=axis, out=out, keepdims=keepdims)) @wraps(jnp.nanargmin) def nanargmin(a, axis=None, out=None, keepdims=None): - return Array(jnp.nanargmin(_remove_brainpy_array(a), axis=axis, out=out, keepdims=keepdims)) + return Array(jnp.nanargmin(_as_jax_array_(a), axis=axis, out=out, keepdims=keepdims)) @wraps(jnp.pad) def pad(array, pad_width, mode="constant", **kwargs): - array = _remove_brainpy_array(array) - pad_width = _remove_brainpy_array(pad_width) - kwargs = {k: _remove_brainpy_array(v) for k, v in kwargs.items()} + array = _as_jax_array_(array) + pad_width = _as_jax_array_(pad_width) + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} return Array(jnp.pad(array, pad_width, mode, **kwargs)) @wraps(jnp.poly) def poly(seq_of_zeros): - seq_of_zeros = _remove_brainpy_array(seq_of_zeros) + seq_of_zeros = _as_jax_array_(seq_of_zeros) return Array(jnp.poly(seq_of_zeros)) @wraps(jnp.polyadd) def polyadd(a1, a2): - a1 = _remove_brainpy_array(a1) - a2 = _remove_brainpy_array(a2) + a1 = _as_jax_array_(a1) + a2 = _as_jax_array_(a2) return Array(jnp.polyadd(a1, a2)) @wraps(jnp.polyder) def polyder(p, m=1): - p = _remove_brainpy_array(p) + p = _as_jax_array_(p) return Array(jnp.polyder(p, m)) @wraps(jnp.polyfit) def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): - x = _remove_brainpy_array(x) - y = _remove_brainpy_array(y) + x = _as_jax_array_(x) + y = 
_as_jax_array_(y) res = jnp.polyfit(x, y, deg, rcond=rcond, full=full, w=w, cov=cov) if isinstance(res, (tuple, list)): return tuple(Array(r) for r in res) @@ -495,99 +495,99 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): @wraps(jnp.polyint) def polyint(p, m=1, k=None): - p = _remove_brainpy_array(p) + p = _as_jax_array_(p) return Array(jnp.polyint(p, m, k)) @wraps(jnp.polymul) def polymul(a1, a2, **kwargs): - a1 = _remove_brainpy_array(a1) - a2 = _remove_brainpy_array(a2) + a1 = _as_jax_array_(a1) + a2 = _as_jax_array_(a2) return Array(jnp.polymul(a1, a2, **kwargs)) @wraps(jnp.polysub) def polysub(a1, a2): - a1 = _remove_brainpy_array(a1) - a2 = _remove_brainpy_array(a2) + a1 = _as_jax_array_(a1) + a2 = _as_jax_array_(a2) return Array(jnp.polysub(a1, a2)) @wraps(jnp.polyval) def polyval(p, x): - p = _remove_brainpy_array(p) - x = _remove_brainpy_array(x) + p = _as_jax_array_(p) + x = _as_jax_array_(x) return Array(jnp.polyval(p, x)) @wraps(jnp.resize) def resize(a, new_shape): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.resize(a, new_shape)) @wraps(jnp.rollaxis) def rollaxis(a, axis: int, start=0): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.rollaxis(a, axis, start)) @wraps(jnp.roots) def roots(p): - p = _remove_brainpy_array(p) + p = _as_jax_array_(p) return Array(jnp.roots(p)) @wraps(jnp.rot90) def rot90(m, k=1, axes=(0, 1)): - m = _remove_brainpy_array(m) + m = _as_jax_array_(m) return Array(jnp.rot90(m, k, axes)) @wraps(jnp.setdiff1d) def setdiff1d(ar1, ar2, assume_unique=False, **kwargs): - return Array(jnp.setdiff1d(_remove_brainpy_array(ar1), - _remove_brainpy_array(ar2), + return Array(jnp.setdiff1d(_as_jax_array_(ar1), + _as_jax_array_(ar2), assume_unique=assume_unique, **kwargs)) @wraps(jnp.setxor1d) def setxor1d(ar1, ar2, assume_unique=False): - return Array(jnp.setxor1d(_remove_brainpy_array(ar1), - _remove_brainpy_array(ar2), + return Array(jnp.setxor1d(_as_jax_array_(ar1), + _as_jax_array_(ar2), assume_unique=assume_unique)) @wraps(jnp.tensordot) def tensordot(a, b, axes=2, **kwargs): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return Array(jnp.tensordot(a, b, axes, **kwargs)) @wraps(jnp.trim_zeros) def trim_zeros(filt, trim='fb'): - return Array(jnp.trim_zeros(_remove_brainpy_array(filt), trim)) + return Array(jnp.trim_zeros(_as_jax_array_(filt), trim)) @wraps(jnp.union1d) def union1d(ar1, ar2, **kwargs): - ar1 = _remove_brainpy_array(ar1) - ar2 = _remove_brainpy_array(ar2) + ar1 = _as_jax_array_(ar1) + ar2 = _as_jax_array_(ar2) return Array(jnp.union1d(ar1, ar2, **kwargs)) @wraps(jnp.unravel_index) def unravel_index(indices, shape): - indices = _remove_brainpy_array(indices) - shape = _remove_brainpy_array(shape) + indices = _as_jax_array_(indices) + shape = _as_jax_array_(shape) return jnp.unravel_index(indices, shape) @wraps(jnp.unwrap) def unwrap(p, discont=jnp.pi, axis: int = -1, period: float = 2 * jnp.pi): - p = _remove_brainpy_array(p) + p = _as_jax_array_(p) return Array(jnp.unwrap(p, discont, axis, period)) @@ -597,39 +597,39 @@ def unwrap(p, discont=jnp.pi, axis: int = -1, period: float = 2 * jnp.pi): # 1. 
Basics @wraps(jnp.isreal) def isreal(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return jnp.isreal(x) @wraps(jnp.isscalar) def isscalar(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return jnp.isscalar(x) @wraps(jnp.real) def real(x): - return jnp.real(_remove_brainpy_array(x)) + return jnp.real(_as_jax_array_(x)) @wraps(jnp.imag) def imag(x): - return jnp.imag(_remove_brainpy_array(x)) + return jnp.imag(_as_jax_array_(x)) @wraps(jnp.conj) def conj(x): - return jnp.conj(_remove_brainpy_array(x)) + return jnp.conj(_as_jax_array_(x)) @wraps(jnp.conjugate) def conjugate(x): - return jnp.conjugate(_remove_brainpy_array(x)) + return jnp.conjugate(_as_jax_array_(x)) @wraps(jnp.ndim) def ndim(x): - return jnp.ndim(_remove_brainpy_array(x)) + return jnp.ndim(_as_jax_array_(x)) # 2. Arithmetic operations @@ -640,312 +640,312 @@ def add(x, y): @wraps(jnp.reciprocal) def reciprocal(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.reciprocal(x)) @wraps(jnp.negative) def negative(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.negative(x)) @wraps(jnp.positive) def positive(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.positive(x)) @wraps(jnp.multiply) def multiply(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.multiply(x1, x2)) @wraps(jnp.divide) def divide(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.divide(x1, x2)) @wraps(jnp.power) def power(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.power(x1, x2)) @wraps(jnp.subtract) def subtract(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.subtract(x1, x2)) @wraps(jnp.true_divide) def true_divide(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.true_divide(x1, x2)) @wraps(jnp.floor_divide) def floor_divide(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.floor_divide(x1, x2)) @wraps(jnp.float_power) def float_power(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.float_power(x1, x2)) @wraps(jnp.fmod) def fmod(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.fmod(x1, x2)) @wraps(jnp.mod) def mod(x1, x2): if isinstance(x1, Array): x1 = x1.value - x2 = _remove_brainpy_array(x2) + x2 = _as_jax_array_(x2) return Array(jnp.mod(x1, x2)) @wraps(jnp.divmod) def divmod(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) r = jnp.divmod(x1, x2) return Array(r[0]), Array(r[1]) @wraps(jnp.remainder) def remainder(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.remainder(x1, x2)) @wraps(jnp.modf) def modf(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.modf(x) return Array(r[0]), Array(r[1]) @wraps(jnp.abs) def abs(x): 
- x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.absolute(x)) @wraps(jnp.absolute) def absolute(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.absolute(x)) # 3. Exponents and logarithms @wraps(jnp.exp) def exp(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.exp(x)) @wraps(jnp.exp2) def exp2(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.exp2(x)) @wraps(jnp.expm1) def expm1(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.expm1(x)) @wraps(jnp.log) def log(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.log(x)) @wraps(jnp.log10) def log10(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.log10(x)) @wraps(jnp.log1p) def log1p(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.log1p(x)) @wraps(jnp.log2) def log2(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.log2(x)) @wraps(jnp.logaddexp) def logaddexp(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.logaddexp(x1, x2)) @wraps(jnp.logaddexp2) def logaddexp2(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.logaddexp2(x1, x2)) # 4. Rational routines @wraps(jnp.lcm) def lcm(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.lcm(x1, x2)) @wraps(jnp.gcd) def gcd(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.gcd(x1, x2)) # 5. 
trigonometric functions @wraps(jnp.arccos) def arccos(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arccos(x)) @wraps(jnp.arccosh) def arccosh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arccosh(x)) @wraps(jnp.arcsin) def arcsin(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arcsin(x)) @wraps(jnp.arcsinh) def arcsinh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arcsinh(x)) @wraps(jnp.arctan) def arctan(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arctan(x)) @wraps(jnp.arctan2) def arctan2(x, y): - x = _remove_brainpy_array(x) - y = _remove_brainpy_array(y) + x = _as_jax_array_(x) + y = _as_jax_array_(y) return Array(jnp.arctan2(x, y)) @wraps(jnp.arctanh) def arctanh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.arctanh(x)) @wraps(jnp.cos) def cos(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.cos(x)) @wraps(jnp.cosh) def cosh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.cosh(x)) @wraps(jnp.sin) def sin(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sin(x)) @wraps(jnp.sinc) def sinc(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sinc(x)) @wraps(jnp.sinh) def sinh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sinh(x)) @wraps(jnp.tan) def tan(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.tan(x)) @wraps(jnp.tanh) def tanh(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.tanh(x)) @wraps(jnp.deg2rad) def deg2rad(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.deg2rad(x)) @wraps(jnp.rad2deg) def rad2deg(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.rad2deg(x)) @wraps(jnp.degrees) def degrees(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.degrees(x)) @wraps(jnp.radians) def radians(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.radians(x)) @wraps(jnp.hypot) def hypot(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.hypot(x1, x2)) # 6. 
Rounding @wraps(jnp.round) def round(a, decimals=0): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.round(a, decimals=decimals)) @@ -955,31 +955,31 @@ def round(a, decimals=0): @wraps(jnp.rint) def rint(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.rint(x)) @wraps(jnp.floor) def floor(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.floor(x)) @wraps(jnp.ceil) def ceil(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.ceil(x)) @wraps(jnp.trunc) def trunc(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.trunc(x)) @wraps(jnp.fix) def fix(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.fix(x)) @@ -988,7 +988,7 @@ def fix(x): @wraps(jnp.prod) def prod(a, axis=None, dtype=None, keepdims=None, initial=None, where=None, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, **kwargs) return r if axis is None else Array(r) @@ -998,39 +998,39 @@ def prod(a, axis=None, dtype=None, keepdims=None, initial=None, where=None, **kw @wraps(jnp.sum) def sum(a, axis=None, dtype=None, keepdims=None, initial=None, where=None, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, **kwargs) return r if axis is None else Array(r) @wraps(jnp.diff) def diff(a, n=1, axis: int = -1, prepend=None, append=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.diff(a, n=n, axis=axis, prepend=prepend, append=append)) @wraps(jnp.median) def median(a, axis=None, keepdims=False, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.median(a, axis=axis, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.nancumprod) def nancumprod(a, axis=None, dtype=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.nancumprod(a=a, axis=axis, dtype=dtype)) @wraps(jnp.nancumsum) def nancumsum(a, axis=None, dtype=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.nancumsum(a=a, axis=axis, dtype=dtype)) @wraps(jnp.cumprod) def cumprod(a, axis=None, dtype=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.cumprod(a=a, axis=axis, dtype=dtype)) @@ -1039,95 +1039,95 @@ def cumprod(a, axis=None, dtype=None): @wraps(jnp.cumsum) def cumsum(a, axis=None, dtype=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.cumsum(a=a, axis=axis, dtype=dtype)) @wraps(jnp.nanprod) def nanprod(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.nanprod(a=a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.nansum) def nansum(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.nansum(a=a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.ediff1d) def ediff1d(a, to_end=None, to_begin=None): - a = _remove_brainpy_array(a) - to_end = _remove_brainpy_array(to_end) - to_begin = _remove_brainpy_array(to_begin) + a = _as_jax_array_(a) + to_end = _as_jax_array_(to_end) + to_begin = _as_jax_array_(to_begin) return Array(jnp.ediff1d(a, to_end=to_end, to_begin=to_begin)) @wraps(jnp.cross) def cross(a, b, axisa=-1, 
axisb=-1, axisc=-1, axis=None): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return Array(jnp.cross(a, b, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)) @wraps(jnp.trapz) def trapz(y, x=None, dx=1.0, axis: int = -1): - y = _remove_brainpy_array(y) - x = _remove_brainpy_array(x) + y = _as_jax_array_(y) + x = _as_jax_array_(x) return jnp.trapz(y, x=x, dx=dx, axis=axis) # 8. floating_functions @wraps(jnp.isfinite) def isfinite(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.isfinite(x)) @wraps(jnp.isinf) def isinf(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.isinf(x)) @wraps(jnp.isnan) def isnan(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.isnan(x)) @wraps(jnp.signbit) def signbit(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.signbit(x)) @wraps(jnp.nextafter) def nextafter(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.nextafter(x1, x2)) @wraps(jnp.copysign) def copysign(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.copysign(x1, x2)) @wraps(jnp.ldexp) def ldexp(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.ldexp(x1, x2)) @wraps(jnp.frexp) def frexp(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) mantissa, exponent = jnp.frexp(x) return Array(mantissa), Array(exponent) @@ -1135,95 +1135,95 @@ def frexp(x): # 9. Miscellaneous @wraps(jnp.convolve) def convolve(a, v, mode='full', **kwargs): - a = _remove_brainpy_array(a) - v = _remove_brainpy_array(v) + a = _as_jax_array_(a) + v = _as_jax_array_(v) return Array(jnp.convolve(a, v, mode, **kwargs)) @wraps(jnp.sqrt) def sqrt(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sqrt(x)) @wraps(jnp.cbrt) def cbrt(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.cbrt(x)) @wraps(jnp.square) def square(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.square(x)) @wraps(jnp.fabs) def fabs(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.fabs(x)) @wraps(jnp.sign) def sign(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sign(x)) @wraps(jnp.heaviside) def heaviside(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.heaviside(x1, x2)) @wraps(jnp.maximum) def maximum(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.maximum(x1, x2)) @wraps(jnp.minimum) def minimum(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.minimum(x1, x2)) @wraps(jnp.fmax) def fmax(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.fmax(x1, x2)) @wraps(jnp.fmin) def fmin(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.fmin(x1, x2)) @wraps(jnp.interp) def interp(x, xp, fp, left=None, right=None, period=None): - x = 
_remove_brainpy_array(x) - xp = _remove_brainpy_array(xp) - fp = _remove_brainpy_array(fp) + x = _as_jax_array_(x) + xp = _as_jax_array_(xp) + fp = _as_jax_array_(fp) return Array(jnp.interp(x, xp, fp, left=left, right=right, period=period)) @wraps(jnp.clip) def clip(a, a_min=None, a_max=None): - a = _remove_brainpy_array(a) - a_min = _remove_brainpy_array(a_min) - a_max = _remove_brainpy_array(a_max) + a = _as_jax_array_(a) + a_min = _as_jax_array_(a_min) + a_max = _as_jax_array_(a_max) return Array(jnp.clip(a, a_min, a_max)) @wraps(jnp.angle) def angle(z, deg=False): - z = _remove_brainpy_array(z) + z = _as_jax_array_(z) a = jnp.angle(z) if deg: a *= 180 / pi @@ -1236,48 +1236,48 @@ def angle(z, deg=False): @wraps(jnp.bitwise_not) def bitwise_not(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.bitwise_not(x)) @wraps(jnp.invert) def invert(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.invert(x)) @wraps(jnp.bitwise_and) def bitwise_and(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.bitwise_and(x1, x2)) @wraps(jnp.bitwise_or) def bitwise_or(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.bitwise_or(x1, x2)) @wraps(jnp.bitwise_xor) def bitwise_xor(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.bitwise_xor(x1, x2)) @wraps(jnp.left_shift) def left_shift(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.left_shift(x1, x2)) @wraps(jnp.right_shift) def right_shift(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.right_shift(x1, x2)) @@ -1287,106 +1287,106 @@ def right_shift(x1, x2): # 1. 
Comparison @wraps(jnp.equal) def equal(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.equal(x1, x2)) @wraps(jnp.not_equal) def not_equal(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.not_equal(x1, x2)) @wraps(jnp.greater) def greater(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.greater(x1, x2)) @wraps(jnp.greater_equal) def greater_equal(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.greater_equal(x1, x2)) @wraps(jnp.less) def less(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.less(x1, x2)) @wraps(jnp.less_equal) def less_equal(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.less_equal(x1, x2)) @wraps(jnp.array_equal) def array_equal(a, b, equal_nan=False): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return jnp.array_equal(a, b, equal_nan=equal_nan) @wraps(jnp.isclose) def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return Array(jnp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) @wraps(jnp.allclose) def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - a = _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) # 2. Logical operations @wraps(jnp.logical_not) def logical_not(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.logical_not(x)) @wraps(jnp.logical_and) def logical_and(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.logical_and(x1, x2)) @wraps(jnp.logical_or) def logical_or(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.logical_or(x1, x2)) @wraps(jnp.logical_xor) def logical_xor(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.logical_xor(x1, x2)) # 3. 
Truth value testing @wraps(jnp.all) def all(a, axis=None, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.all(a=a, axis=axis, keepdims=keepdims, where=where) return r if axis is None else Array(r) @wraps(jnp.any) def any(a, axis=None, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.any(a=a, axis=axis, keepdims=keepdims, where=where) return r if axis is None else Array(r) @@ -1401,62 +1401,62 @@ def any(a, axis=None, keepdims=None, where=None): @wraps(jnp.shape) def shape(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return jnp.shape(x) @wraps(jnp.size) def size(x, axis=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.size(x, axis=axis) return r if axis is None else Array(r) @wraps(jnp.reshape) def reshape(x, newshape, order="C"): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.reshape(x, newshape, order=order)) @wraps(jnp.ravel) def ravel(x, order="C"): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.ravel(x, order=order)) @wraps(jnp.moveaxis) def moveaxis(x, source, destination): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.moveaxis(x, source, destination)) @wraps(jnp.transpose) def transpose(x, axis=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.transpose(x, axes=axis)) @wraps(jnp.swapaxes) def swapaxes(x, axis1, axis2): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.swapaxes(x, axis1, axis2)) @wraps(jnp.concatenate) def concatenate(arrays, axis: int = 0): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.concatenate(arrays, axis)) @wraps(jnp.stack) def stack(arrays, axis: int = 0): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.stack(arrays, axis)) @wraps(jnp.vstack) def vstack(arrays): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.vstack(arrays)) @@ -1465,19 +1465,19 @@ def vstack(arrays): @wraps(jnp.hstack) def hstack(arrays): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.hstack(arrays)) @wraps(jnp.dstack) def dstack(arrays): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.dstack(arrays)) @wraps(jnp.column_stack) def column_stack(arrays): - arrays = [_remove_brainpy_array(a) for a in arrays] + arrays = [_as_jax_array_(a) for a in arrays] return Array(jnp.column_stack(arrays)) @@ -1505,20 +1505,20 @@ def vsplit(ary, indices_or_sections): @wraps(jnp.tile) def tile(A, reps): - A = _remove_brainpy_array(A) + A = _as_jax_array_(A) return Array(jnp.tile(A, reps)) @wraps(jnp.repeat) def repeat(x, repeats, axis=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.repeat(x, repeats=repeats, axis=axis, **kwargs)) @wraps(jnp.unique) def unique(x, return_index=False, return_inverse=False, return_counts=False, axis=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) res = jnp.unique(x, return_index=return_index, return_inverse=return_inverse, @@ -1533,112 +1533,112 @@ def unique(x, return_index=False, return_inverse=False, @wraps(jnp.append) def append(arr, values, axis=None): - arr = _remove_brainpy_array(arr) - values = 
_remove_brainpy_array(values) + arr = _as_jax_array_(arr) + values = _as_jax_array_(values) return Array(jnp.append(arr, values, axis=axis)) @wraps(jnp.flip) def flip(x, axis=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.flip(x, axis=axis)) @wraps(jnp.fliplr) def fliplr(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.fliplr(x)) @wraps(jnp.flipud) def flipud(x): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.flipud(x)) @wraps(jnp.roll) def roll(x, shift, axis=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.roll(x, shift, axis=axis)) @wraps(jnp.atleast_1d) def atleast_1d(*arys): - return jnp.atleast_1d(*[_remove_brainpy_array(a) for a in arys]) + return jnp.atleast_1d(*[_as_jax_array_(a) for a in arys]) @wraps(jnp.atleast_2d) def atleast_2d(*arys): - return jnp.atleast_2d(*[_remove_brainpy_array(a) for a in arys]) + return jnp.atleast_2d(*[_as_jax_array_(a) for a in arys]) @wraps(jnp.atleast_3d) def atleast_3d(*arys): - return jnp.atleast_3d(*[_remove_brainpy_array(a) for a in arys]) + return jnp.atleast_3d(*[_as_jax_array_(a) for a in arys]) @wraps(jnp.expand_dims) def expand_dims(x, axis): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.expand_dims(x, axis=axis)) @wraps(jnp.squeeze) def squeeze(x, axis=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.squeeze(x, axis=axis)) @wraps(jnp.sort) def sort(x, axis=-1, kind='quicksort', order=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.sort(x, axis=axis, kind=kind, order=order)) @wraps(jnp.argsort) def argsort(x, axis=-1, kind='stable', order=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.argsort(x, axis=axis, kind=kind, order=order)) @wraps(jnp.argmax) def argmax(x, axis=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.argmax(x, axis=axis, **kwargs) return r if axis is None else Array(r) @wraps(jnp.argmin) def argmin(x, axis=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.argmin(x, axis=axis, **kwargs) return r if axis is None else Array(r) @wraps(jnp.argwhere) def argwhere(x, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.argwhere(x, **kwargs)) @wraps(jnp.nonzero) def nonzero(x, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) res = jnp.nonzero(x, **kwargs) return tuple([Array(r) for r in res]) if isinstance(res, tuple) else Array(res) @wraps(jnp.flatnonzero) def flatnonzero(x, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.flatnonzero(x, **kwargs)) @wraps(jnp.where) def where(condition, x=None, y=None, **kwargs): - condition = _remove_brainpy_array(condition) - x = _remove_brainpy_array(x) - y = _remove_brainpy_array(y) + condition = _as_jax_array_(condition) + x = _as_jax_array_(x) + y = _as_jax_array_(y) res = jnp.where(condition, x=x, y=y, **kwargs) if isinstance(res, tuple): return tuple(Array(r) for r in res) @@ -1648,34 +1648,34 @@ def where(condition, x=None, y=None, **kwargs): @wraps(jnp.searchsorted) def searchsorted(a, v, side='left', sorter=None): - a = _remove_brainpy_array(a) - v = _remove_brainpy_array(v) + a = _as_jax_array_(a) + v = _as_jax_array_(v) return Array(jnp.searchsorted(a, v, side=side, sorter=sorter)) @wraps(jnp.extract) def extract(condition, arr): - condition = _remove_brainpy_array(condition) - arr = _remove_brainpy_array(arr) + 
condition = _as_jax_array_(condition) + arr = _as_jax_array_(arr) return Array(jnp.extract(condition, arr)) @wraps(jnp.count_nonzero) def count_nonzero(a, axis=None, keepdims=False): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return jnp.count_nonzero(a, axis=axis, keepdims=keepdims) @wraps(jnp.max) def max(a, axis=None, out=None, keepdims=None, initial=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) return r if axis is None else Array(r) @wraps(jnp.min) def min(a, axis=None, out=None, keepdims=None, initial=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) return r if axis is None else Array(r) @@ -1686,13 +1686,13 @@ def min(a, axis=None, out=None, keepdims=None, initial=None, where=None): @wraps(jnp.apply_along_axis) def apply_along_axis(func1d, axis: int, arr, *args, **kwargs): - arr = _remove_brainpy_array(arr) + arr = _as_jax_array_(arr) return jnp.apply_along_axis(func1d, axis, arr, *args, **kwargs) @wraps(jnp.apply_over_axes) def apply_over_axes(func, a, axes): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return jnp.apply_over_axes(func, a, axes) @@ -1712,23 +1712,23 @@ def array_equiv(a1, a2): @wraps(jnp.array_repr) def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): - arr = _remove_brainpy_array(arr) + arr = _as_jax_array_(arr) return jnp.array_repr(arr, max_line_width=max_line_width, precision=precision, suppress_small=suppress_small) @wraps(jnp.array_str) def array_str(a, max_line_width=None, precision=None, suppress_small=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return jnp.array_str(a, max_line_width=max_line_width, precision=precision, suppress_small=suppress_small) @wraps(jnp.array_split) def array_split(ary, indices_or_sections, axis: int = 0): - ary = _remove_brainpy_array(ary) + ary = _as_jax_array_(ary) if isinstance(indices_or_sections, Array): indices_or_sections = indices_or_sections.value elif isinstance(indices_or_sections, (tuple, list)): - indices_or_sections = [_remove_brainpy_array(i) for i in indices_or_sections] + indices_or_sections = [_as_jax_array_(i) for i in indices_or_sections] return tuple([Array(a) for a in jnp.array_split(ary, indices_or_sections, axis)]) @@ -1756,25 +1756,25 @@ def empty(shape, dtype=None): @wraps(jnp.zeros_like) def zeros_like(a, dtype=None, shape=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) @wraps(jnp.ones_like) def ones_like(a, dtype=None, shape=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) @wraps(jnp.empty_like) def empty_like(a, dtype=None, shape=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) @wraps(jnp.full_like) def full_like(a, fill_value, dtype=None, shape=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.full_like(a, fill_value, dtype=dtype, shape=shape)) @@ -1790,12 +1790,12 @@ def identity(n, dtype=None): @wraps(jnp.array) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) try: res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) except TypeError: leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, 
Array)) - leaves = [_remove_brainpy_array(l) for l in leaves] + leaves = [_as_jax_array_(l) for l in leaves] a = tree_unflatten(tree, leaves) res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) return Array(res) @@ -1820,12 +1820,12 @@ def asarray(a, dtype=None, order=None): ArrayType interpretation of `a`. No copy is performed if the input is already an ndarray with matching dtype. """ - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) try: res = jnp.asarray(a=a, dtype=dtype, order=order) except TypeError: leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_remove_brainpy_array(l) for l in leaves] + leaves = [_as_jax_array_(l) for l in leaves] arrays = tree_unflatten(tree, leaves) res = jnp.asarray(a=arrays, dtype=dtype, order=order) return Array(res) @@ -1833,15 +1833,15 @@ def asarray(a, dtype=None, order=None): @wraps(jnp.arange) def arange(*args, **kwargs): - args = [_remove_brainpy_array(a) for a in args] - kwargs = {k: _remove_brainpy_array(v) for k, v in kwargs.items()} + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} return Array(jnp.arange(*args, **kwargs)) @wraps(jnp.linspace) def linspace(*args, **kwargs): - args = [_remove_brainpy_array(a) for a in args] - kwargs = {k: _remove_brainpy_array(v) for k, v in kwargs.items()} + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} res = jnp.linspace(*args, **kwargs) if isinstance(res, tuple): return Array(res[0]), res[1] @@ -1851,21 +1851,21 @@ def linspace(*args, **kwargs): @wraps(jnp.logspace) def logspace(*args, **kwargs): - args = [_remove_brainpy_array(a) for a in args] - kwargs = {k: _remove_brainpy_array(v) for k, v in kwargs.items()} + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} return Array(jnp.logspace(*args, **kwargs)) @wraps(jnp.meshgrid) def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): - xi = [_remove_brainpy_array(x) for x in xi] + xi = [_as_jax_array_(x) for x in xi] rr = jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) return list(Array(r) for r in rr) @wraps(jnp.diag) def diag(a, k=0): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.diag(a, k)) @@ -1876,19 +1876,19 @@ def tri(N, M=None, k=0, dtype=None): @wraps(jnp.tril) def tril(a, k=0): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.tril(a, k)) @wraps(jnp.triu) def triu(a, k=0): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.triu(a, k)) @wraps(jnp.vander) def vander(x, N=None, increasing=False): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.vander(x, N=N, increasing=increasing)) @@ -1898,7 +1898,7 @@ def fill_diagonal(a, val): raise ValueError(f'Must be a ArrayType, but got {type(a)}') if a.ndim < 2: raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') - val = _remove_brainpy_array(val) + val = _as_jax_array_(val) i, j = jnp.diag_indices(_min(a.shape[-2:])) a._value = a.value.at[..., i, j].set(val) @@ -1912,7 +1912,7 @@ def fill_diagonal(a, val): @wraps(jnp.tril_indices_from) def tril_indices_from(x, k=0): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) res = jnp.tril_indices_from(x, k=k) if isinstance(res, tuple): return tuple(Array(r) for r in res) @@ -1922,7 +1922,7 @@ def tril_indices_from(x, k=0): @wraps(jnp.triu_indices_from) def triu_indices_from(x, k=0): - x = _remove_brainpy_array(x) + x = 
_as_jax_array_(x) res = jnp.triu_indices_from(x, k=k) if isinstance(res, tuple): return tuple(Array(r) for r in res) @@ -1932,15 +1932,15 @@ def triu_indices_from(x, k=0): @wraps(jnp.take) def take(x, indices, axis=None, mode=None): - x = _remove_brainpy_array(x) - indices = _remove_brainpy_array(indices) + x = _as_jax_array_(x) + indices = _as_jax_array_(indices) return Array(jnp.take(x, indices=indices, axis=axis, mode=mode)) @wraps(jnp.select) def select(condlist, choicelist, default=0): - condlist = [_remove_brainpy_array(c) for c in condlist] - choicelist = [_remove_brainpy_array(c) for c in choicelist] + condlist = [_as_jax_array_(c) for c in condlist] + choicelist = [_as_jax_array_(c) for c in choicelist] return Array(jnp.select(condlist, choicelist, default=default)) @@ -1948,21 +1948,21 @@ def select(condlist, choicelist, default=0): # --------------- @wraps(jnp.nanmin) def nanmin(x, axis=None, keepdims=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.nanmin(x, axis=axis, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.nanmax) def nanmax(x, axis=None, keepdims=None, **kwargs): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.nanmax(x, axis=axis, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.ptp) def ptp(x, axis=None, keepdims=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) r = jnp.ptp(x, axis=axis, keepdims=keepdims) return r if axis is None else Array(r) @@ -1976,8 +1976,8 @@ def percentile(a, method: str = "linear", keepdims: bool = False, interpolation=None): - a = _remove_brainpy_array(a) - q = _remove_brainpy_array(q) + a = _as_jax_array_(a) + q = _as_jax_array_(q) r = jnp.percentile(a=a, q=q, axis=axis, @@ -1998,8 +1998,8 @@ def nanpercentile(a, method: str = "linear", keepdims: bool = False, interpolation=None): - a = _remove_brainpy_array(a) - q = _remove_brainpy_array(q) + a = _as_jax_array_(a) + q = _as_jax_array_(q) r = jnp.nanpercentile(a=a, q=q, axis=axis, @@ -2015,8 +2015,8 @@ def nanpercentile(a, def quantile(a, q, axis=None, out=None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation=None): - a = _remove_brainpy_array(a) - q = _remove_brainpy_array(q) + a = _as_jax_array_(a) + q = _as_jax_array_(q) r = jnp.quantile(a=a, q=q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims, interpolation=interpolation) return r if axis is None else Array(r) @@ -2026,8 +2026,8 @@ def quantile(a, q, axis=None, out=None, overwrite_input: bool = False, method: s def nanquantile(a, q, axis=None, out=None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation=None): - a = _remove_brainpy_array(a) - q = _remove_brainpy_array(q) + a = _as_jax_array_(a) + q = _as_jax_array_(q) r = jnp.nanquantile(a=a, q=q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims, interpolation=interpolation) return r if axis is None else Array(r) @@ -2035,8 +2035,8 @@ def nanquantile(a, q, axis=None, out=None, overwrite_input: bool = False, method @wraps(jnp.average) def average(a, axis=None, weights=None, returned=False): - a = _remove_brainpy_array(a) - weights = _remove_brainpy_array(weights) + a = _as_jax_array_(a) + weights = _as_jax_array_(weights) r = jnp.average(a, axis=axis, weights=weights, returned=returned) if axis is None: return r @@ -2048,21 +2048,21 @@ def average(a, axis=None, weights=None, returned=False): @wraps(jnp.mean) 
def mean(a, axis=None, dtype=None, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, where=where) return r if axis is None else Array(r) @wraps(jnp.std) def std(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.std(a=a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) return r if axis is None else Array(r) @wraps(jnp.var) def var(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) return r if axis is None else Array(r) @@ -2074,69 +2074,69 @@ def nanmedian(a, axis=None, keepdims=False): @wraps(jnp.nanmean) def nanmean(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.nanmean(a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) return r if axis is None else Array(r) @wraps(jnp.nanstd) def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.nanstd(a=a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) return r if axis is None else Array(r) @wraps(jnp.nanvar) def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) r = jnp.nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) return r if axis is None else Array(r) @wraps(jnp.corrcoef) def corrcoef(x, y=None, rowvar=True): - x = _remove_brainpy_array(x) - y = _remove_brainpy_array(y) + x = _as_jax_array_(x) + y = _as_jax_array_(y) return Array(jnp.corrcoef(x, y, rowvar)) @wraps(jnp.correlate) def correlate(a, v, mode='valid', **kwargs): - a = _remove_brainpy_array(a) - v = _remove_brainpy_array(v) + a = _as_jax_array_(a) + v = _as_jax_array_(v) return Array(jnp.correlate(a, v, mode, **kwargs)) @wraps(jnp.cov) def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None): - m = _remove_brainpy_array(m) - y = _remove_brainpy_array(y) - fweights = _remove_brainpy_array(fweights) - aweights = _remove_brainpy_array(aweights) + m = _as_jax_array_(m) + y = _as_jax_array_(y) + fweights = _as_jax_array_(fweights) + aweights = _as_jax_array_(aweights) return Array(jnp.cov(m, y=y, rowvar=rowvar, bias=bias, ddof=ddof, fweights=fweights, aweights=aweights)) @wraps(jnp.histogram) def histogram(a, bins=10, range=None, weights=None, density=None): - a = _remove_brainpy_array(a) - weights = _remove_brainpy_array(weights) + a = _as_jax_array_(a) + weights = _as_jax_array_(weights) hist, bin_edges = jnp.histogram(a=a, bins=bins, range=range, weights=weights, density=density) return Array(hist), Array(bin_edges) @wraps(jnp.bincount) def bincount(x, weights=None, minlength=0, length=None, **kwargs): - x = _remove_brainpy_array(x) - weights = _remove_brainpy_array(weights) + x = _as_jax_array_(x) + weights = _as_jax_array_(weights) res = jnp.bincount(x, weights=weights, minlength=minlength, length=length, **kwargs) return Array(res) @wraps(jnp.digitize) def digitize(x, bins, right=False): - x = _remove_brainpy_array(x) - bins = _remove_brainpy_array(bins) + x = _as_jax_array_(x) + bins = _as_jax_array_(bins) return Array(jnp.digitize(x, bins=bins, right=right)) @@ -2179,49 +2179,49 @@ def kaiser(M, beta): @wraps(jnp.dot) def dot(x1, x2, **kwargs): - x1 = 
_remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.dot(x1, x2, **kwargs)) @wraps(jnp.vdot) def vdot(x1, x2, **kwargs): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.vdot(x1, x2, **kwargs)) @wraps(jnp.inner) def inner(x1, x2, **kwargs): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.inner(x1, x2, **kwargs)) @wraps(jnp.outer) def outer(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.outer(x1, x2)) @wraps(jnp.kron) def kron(x1, x2): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.kron(x1, x2)) @wraps(jnp.matmul) def matmul(x1, x2, **kwargs): - x1 = _remove_brainpy_array(x1) - x2 = _remove_brainpy_array(x2) + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) return Array(jnp.matmul(x1, x2, **kwargs)) @wraps(jnp.trace) def trace(x, offset=0, axis1=0, axis2=1, dtype=None): - x = _remove_brainpy_array(x) + x = _as_jax_array_(x) return Array(jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) @@ -2279,15 +2279,15 @@ def can_cast(from_, to, casting=None): True if cast can occur according to the casting rule. """ - from_ = _remove_brainpy_array(from_) - to = _remove_brainpy_array(to) + from_ = _as_jax_array_(from_) + to = _as_jax_array_(to) return jnp.can_cast(from_, to, casting=casting) @wraps(jnp.choose) def choose(a, choices, mode='raise'): - a = _remove_brainpy_array(a) - choices = [_remove_brainpy_array(c) for c in choices] + a = _as_jax_array_(a) + choices = [_as_jax_array_(c) for c in choices] return jnp.choose(a, choices, mode=mode) @@ -2309,7 +2309,7 @@ def fromfunction(function, shape, dtype=float, **kwargs): def fromiter(iterable, dtype, count=-1, *args, **kwargs): - iterable = _remove_brainpy_array(iterable) + iterable = _as_jax_array_(iterable) return asarray(np.fromiter(iterable, dtype=dtype, count=count, *args, **kwargs)) @@ -2321,17 +2321,17 @@ def fromstring(string, dtype=float, count=-1, *, sep): def iscomplexobj(x): - return np.iscomplexobj(_remove_brainpy_array(x)) + return np.iscomplexobj(_as_jax_array_(x)) @wraps(jnp.isneginf) def isneginf(x): - return Array(jnp.isneginf(_remove_brainpy_array(x))) + return Array(jnp.isneginf(_as_jax_array_(x))) @wraps(jnp.isposinf) def isposinf(x): - return Array(jnp.isposinf(_remove_brainpy_array(x))) + return Array(jnp.isposinf(_as_jax_array_(x))) def isrealobj(x): @@ -2343,18 +2343,18 @@ def isrealobj(x): def iterable(x): - return np.iterable(_remove_brainpy_array(x)) + return np.iterable(_as_jax_array_(x)) @wraps(jnp.packbits) def packbits(a, axis: Optional[int] = None, bitorder='big'): - return Array(jnp.packbits(_remove_brainpy_array(a), axis=axis, bitorder=bitorder)) + return Array(jnp.packbits(_as_jax_array_(a), axis=axis, bitorder=bitorder)) @wraps(jnp.piecewise) def piecewise(x, condlist, funclist, *args, **kw): condlist = asarray(condlist, dtype=bool) - return Array(jnp.piecewise(_remove_brainpy_array(x), condlist.value, funclist, *args, **kw)) + return Array(jnp.piecewise(_as_jax_array_(x), condlist.value, funclist, *args, **kw)) printoptions = np.printoptions @@ -2363,31 +2363,31 @@ def piecewise(x, condlist, funclist, *args, **kw): @wraps(jnp.promote_types) def promote_types(a, b): - a 
= _remove_brainpy_array(a) - b = _remove_brainpy_array(b) + a = _as_jax_array_(a) + b = _as_jax_array_(b) return jnp.promote_types(a, b) @wraps(jnp.ravel_multi_index) def ravel_multi_index(multi_index, dims, mode='raise', order='C'): - multi_index = [_remove_brainpy_array(i) for i in multi_index] + multi_index = [_as_jax_array_(i) for i in multi_index] return Array(jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)) @wraps(jnp.result_type) def result_type(*args): - args = [_remove_brainpy_array(a) for a in args] + args = [_as_jax_array_(a) for a in args] return jnp.result_type(*args) @wraps(jnp.sort_complex) def sort_complex(a): - return Array(jnp.sort_complex(_remove_brainpy_array(a))) + return Array(jnp.sort_complex(_as_jax_array_(a))) @wraps(jnp.unpackbits) def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): - a = _remove_brainpy_array(a) + a = _as_jax_array_(a) return Array(jnp.unpackbits(a, axis, count=count, bitorder=bitorder)) @@ -2484,8 +2484,8 @@ def place(arr, mask, vals): @wraps(jnp.polydiv) def polydiv(u, v, **kwargs): - u = _remove_brainpy_array(u) - v = _remove_brainpy_array(v) + u = _as_jax_array_(u) + v = _as_jax_array_(v) res = jnp.polydiv(u, v, **kwargs) if isinstance(res, tuple): return tuple(Array(r) for r in res) diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index 8578efbbe..6615cc7d0 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -21,7 +21,6 @@ from brainpy.math.numpy_ops import as_device_array from ._utils import infer_dyn_vars from .base import ObjectTransform -from brainpy.types import PyTree __all__ = [ 'make_loop', @@ -647,7 +646,7 @@ def for_loop( body_fun: Callable, operands: Any, dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, - out_vars: Optional[PyTree] = None, + out_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, reverse: bool = False, unroll: int = 1, ): diff --git a/brainpy/math/random.py b/brainpy/math/random.py index f2c0ffa68..ffead3a19 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -4,6 +4,7 @@ from collections import namedtuple from functools import partial from operator import index +from typing import Optional, Union import jax import numpy as np @@ -12,9 +13,9 @@ from jax.experimental.host_callback import call from jax.tree_util import register_pytree_node -from brainpy.math.ndarray import Array, Variable from brainpy.check import jit_error_checking from brainpy.errors import UnsupportedError +from brainpy.math.ndarray import Array, Variable from ._utils import wraps __all__ = [ @@ -402,12 +403,14 @@ class RandomState(Variable): """RandomState that track the random generator state. """ __slots__ = () - def __init__(self, seed_or_key=None, seed=None): + def __init__(self, + seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None, + seed: Optional[int] = None): """RandomState constructor. Parameters ---------- - seed_or_key: int, ArrayType, optional + seed_or_key: int, Array, optional It can be an integer for initial seed of the random number generator, or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype. @@ -450,6 +453,9 @@ def __repr__(self) -> str: # seed and random key # # ------------------- # + def clone(self): + return type(self)(self.split_key()) + def seed(self, seed_or_key=None, seed=None): """Sets a new random seed. 
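A minimal usage sketch of the `RandomState` changes above, namely the explicitly typed `seed_or_key` argument and the new `clone()` helper. It assumes the usual `import brainpy.math as bm` entry point, and the variable names are illustrative only::

    >>> import brainpy.math as bm
    >>> rng = bm.random.RandomState(seed_or_key=42)   # an int seed, or a uint32 key with two elements
    >>> child = rng.clone()                           # a new RandomState built from rng.split_key()
    >>> samples = child.normal(size=(3,))             # the clone draws from its own key, independent of rng
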
diff --git a/brainpy/tests/test_checking.py b/brainpy/test_check.py similarity index 100% rename from brainpy/tests/test_checking.py rename to brainpy/test_check.py diff --git a/brainpy/test_checkpoints.py b/brainpy/test_checkpoints.py new file mode 100644 index 000000000..40a96afc6 --- /dev/null +++ b/brainpy/test_checkpoints.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/brainpy/types.py b/brainpy/types.py index 926c65fe2..a9124e482 100644 --- a/brainpy/types.py +++ b/brainpy/types.py @@ -2,9 +2,10 @@ from typing import TypeVar, Tuple -import numpy as np import jax.numpy as jnp +import numpy as np +from brainpy.math.ndarray import Array, Variable, TrainVar __all__ = [ 'ArrayType', 'Parameter', 'PyTree', @@ -15,7 +16,7 @@ # data Parameter = TypeVar('Parameter', float, int, jnp.ndarray, 'Array', 'Variable') # noqa -ArrayType = TypeVar('ArrayType', 'Array', 'Variable', 'TrainVar', jnp.ndarray, np.ndarray) # noqa +ArrayType = TypeVar('ArrayType', Array, Variable, TrainVar, jnp.ndarray, np.ndarray) # noqa Array = ArrayType # noqa PyTree = TypeVar('PyTree') # noqa From 3973f3c347c56e2c34fea4695c8c93de8a640cb5 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 27 Dec 2022 14:56:14 +0800 Subject: [PATCH 5/6] enable x64 setting in an environment --- brainpy/math/environment.py | 45 +++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index c1400cd71..fa7b9c763 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -9,7 +9,7 @@ import warnings from typing import Any, Callable, TypeVar, cast -from jax import dtypes, config, numpy as jnp, devices +from jax import config, numpy as jnp, devices from jax.lib import xla_bridge from . import modes @@ -329,6 +329,7 @@ def clone(self): def set_environment( mode: modes.Mode = None, dt: float = None, + x64: bool = None, complex_: type = None, float_: type = None, int_: type = None, @@ -342,6 +343,8 @@ def set_environment( The computing mode. dt: float The numerical integration precision. + x64: bool + Enable x64 computation. complex_: type The complex data type. float_ @@ -359,6 +362,10 @@ def set_environment( assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' set_mode(mode) + if x64 is not None: + assert isinstance(x64, bool), f'"x64" must be a bool.' + set_x64(x64) + if float_ is not None: assert isinstance(float_, type), '"float_" must a float.' set_float(float_) @@ -402,8 +409,9 @@ class environment(_DecoratorContextManager): def __init__( self, - dt: float = None, mode: modes.Mode = None, + dt: float = None, + x64: bool = None, complex_: type = None, float_: type = None, int_: type = None, @@ -412,6 +420,7 @@ def __init__( super().__init__() self.old_dt = get_dt() self.old_mode = get_mode() + self.old_x64 = config.read("jax_enable_x64") self.old_int = get_int() self.old_bool = get_bool() self.old_float = get_float() @@ -421,6 +430,8 @@ def __init__( assert isinstance(dt, float), '"dt" must a float.' if mode is not None: assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' + if x64 is not None: + assert isinstance(x64, bool), f'"x64" must be a bool.' if float_ is not None: assert isinstance(float_, type), '"float_" must a float.' if int_ is not None: @@ -431,6 +442,7 @@ def __init__( assert isinstance(complex_, type), '"complex_" must a type.' 
self.dt = dt self.mode = mode + self.x64 = x64 self.complex_ = complex_ self.float_ = float_ self.int_ = int_ @@ -439,6 +451,7 @@ def __init__( def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) if self.mode is not None: set_mode(self.mode) + if self.x64 is not None: set_x64(self.x64) if self.float_ is not None: set_float(self.float_) if self.int_ is not None: set_int(self.int_) if self.complex_ is not None: set_complex(self.complex_) @@ -448,6 +461,7 @@ def __enter__(self) -> 'environment': def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.dt is not None: set_dt(self.old_dt) if self.mode is not None: set_mode(self.old_mode) + if self.x64 is not None: set_x64(self.old_x64) if self.int_ is not None: set_int(self.old_int) if self.float_ is not None: set_float(self.old_float) if self.complex_ is not None: set_complex(self.old_complex) @@ -456,6 +470,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def clone(self): return self.__class__(dt=self.dt, mode=self.mode, + x64=self.x64, bool_=self.bool_, complex_=self.complex_, float_=self.float_, @@ -468,6 +483,7 @@ class training_environment(environment): This is a short-cut context setting for an environment with the training mode. It is equivalent to:: + >>> import brainpy.math as bm >>> with bm.environment(mode=bm.training_mode): >>> pass @@ -476,11 +492,17 @@ class training_environment(environment): def __init__(self, dt: float = None, + x64: bool = None, complex_: type = None, float_: type = None, int_: type = None, bool_: type = None): - super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_, + super().__init__(dt=dt, + x64=x64, + complex_=complex_, + float_=float_, + int_=int_, + bool_=bool_, mode=modes.TrainingMode()) @@ -490,6 +512,7 @@ class batching_environment(environment): This is a short-cut context setting for an environment with the batching mode. It is equivalent to:: + >>> import brainpy.math as bm >>> with bm.environment(mode=bm.batching_mode): >>> pass @@ -498,11 +521,17 @@ class batching_environment(environment): def __init__(self, dt: float = None, + x64: bool = None, complex_: type = None, float_: type = None, int_: type = None, bool_: type = None): - super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_, + super().__init__(dt=dt, + x64=x64, + complex_=complex_, + float_=float_, + int_=int_, + bool_=bool_, mode=modes.BatchingMode()) @@ -520,6 +549,14 @@ def disable_x64(): set_complex(jnp.complex64) +def set_x64(enable: bool): + assert isinstance(enable, bool) + if enable: + enable_x64() + else: + disable_x64() + + def set_platform(platform: str): """ Changes platform to CPU, GPU, or TPU. 
This utility only takes From d37c113847526f07d1ca1450815f5540ce50c000 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 27 Dec 2022 14:57:35 +0800 Subject: [PATCH 6/6] unified gradient apis --- brainpy/math/object_transform/autograd_old.py | 1001 +++++++++++++++++ 1 file changed, 1001 insertions(+) create mode 100644 brainpy/math/object_transform/autograd_old.py diff --git a/brainpy/math/object_transform/autograd_old.py b/brainpy/math/object_transform/autograd_old.py new file mode 100644 index 000000000..b726201f7 --- /dev/null +++ b/brainpy/math/object_transform/autograd_old.py @@ -0,0 +1,1001 @@ +# -*- coding: utf-8 -*- + +from functools import partial +from typing import Union, Callable, Dict, Sequence, Any + +import jax +import numpy as np +from jax import linear_util, dtypes, vmap, numpy as jnp +from jax._src.api import (_vjp, _jvp, + _check_callable, + _check_output_dtype_jacrev, _check_input_dtype_jacrev, + _check_output_dtype_jacfwd, _check_input_dtype_jacfwd, ) +from jax.api_util import argnums_partial +from jax.errors import UnexpectedTracerError +from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_transpose, tree_structure +from jax.util import safe_map + +from brainpy import errors, tools +from brainpy.base import get_unique_name, ArrayCollector +from brainpy.math.ndarray import Array, add_context, del_context +from ._utils import infer_dyn_vars +from .base import ObjectTransform + +__all__ = [ + 'grad', # gradient of scalar function + 'vector_grad', # gradient of vector/matrix/... + 'jacobian', 'jacrev', 'jacfwd', # gradient of jacobian + 'hessian', # gradient of hessian +] + + +class GradientFunTransform(ObjectTransform): + _excluded_vars = ('_origin_fun',) + + def __init__( + self, + grad_func: Callable, + dyn_vars: Any, + grad_vars: Any, + name: str = None, + origin_fun=None + ): + super().__init__(name=name) + + self._origin_fun = origin_fun + self._f = grad_func + self.register_implicit_vars(dyn_vars, grad_vars) + + def __call__(self, *args, **kwargs): + return self._f(*args, **kwargs) + + def __repr__(self): + name = self.__class__.__name__ + f = tools.repr_object(self._origin_fun) + f = tools.repr_context(f, " " * (len(name) + 6)) + return f'{name}(target={f})' + + +class GradientTransform(ObjectTransform): + _excluded_vars = ('_origin_fun',) + + def __init__( + self, + grad_func: Callable, + grad_tree, + grad_vars, + dyn_vars, + argnums, + return_value: bool, + has_aux: bool, + name: str = None, + origin_fun=None + ): + super().__init__(name=name) + + self.register_implicit_vars(dyn_vars, grad_vars) + self._grad_func = grad_func + self._grad_tree = grad_tree + self._grad_vars = grad_vars + self._dyn_vars = dyn_vars + self._argnums = argnums + self._return_value = return_value + self._has_aux = has_aux + self._origin_fun = origin_fun + + self.register_implicit_vars(dyn_vars, grad_vars) + + def __repr__(self): + name = self.__class__.__name__ + f = tools.repr_object(self._origin_fun) + f = tools.repr_context(f, " " * (len(name) + 6)) + format_ref = (f'{name}(target={f}, \n' + + f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n' + f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})') + return format_ref + + def __call__(self, *args, **kwargs): + old_grad_vs = [v.value for v in self._grad_vars] + old_dyn_vs = [v.value for v in self._dyn_vars] + try: + add_context(self.name) + grads, (outputs, new_grad_vs, new_dyn_vs) = self._grad_func(old_grad_vs, + old_dyn_vs, + *args, + **kwargs) + del_context(self._name) + except 
UnexpectedTracerError as e: + del_context(self._name) + for v, d in zip(self._grad_vars, old_grad_vs): v._value = d + for v, d in zip(self._dyn_vars, old_dyn_vs): v._value = d + raise errors.JaxTracerError(variables=self._dyn_vars + self._grad_vars) from e + except Exception as e: + del_context(self._name) + for v, d in zip(self._grad_vars, old_grad_vs): v._value = d + for v, d in zip(self._dyn_vars, old_dyn_vs): v._value = d + raise e + else: + for v, d in zip(self._grad_vars, new_grad_vs): v._value = d + for v, d in zip(self._dyn_vars, new_dyn_vs): v._value = d + + # check returned grads + if len(self._grad_vars) == 0: + grads = grads[1] if isinstance(self._argnums, int) else grads[1:] + else: + var_grads = self._grad_tree.unflatten(grads[0]) + if self._argnums is None: + grads = var_grads + else: + arg_grads = grads[1] if isinstance(self._argnums, int) else grads[1:] + grads = (var_grads, arg_grads) + + # check returned value + if self._return_value: + # check aux + if self._has_aux: + return grads, outputs[0], outputs[1] + else: + return grads, outputs + else: + # check aux + if self._has_aux: + return grads, outputs[1] + else: + return grads + + +def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars, + argnums, return_value, has_aux): + name = get_unique_name('_brainpy_object_oriented_grad_') + + # outputs + def call_func(*args, **kwargs): + old_grad_vs = [v.value for v in grad_vars] + old_dyn_vs = [v.value for v in dyn_vars] + try: + add_context(name) + grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs, + old_dyn_vs, + *args, + **kwargs) + del_context(name) + except UnexpectedTracerError as e: + del_context(name) + for v, d in zip(grad_vars, old_grad_vs): v._value = d + for v, d in zip(dyn_vars, old_dyn_vs): v._value = d + raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e + except Exception as e: + del_context(name) + for v, d in zip(grad_vars, old_grad_vs): v._value = d + for v, d in zip(dyn_vars, old_dyn_vs): v._value = d + raise e + for v, d in zip(grad_vars, new_grad_vs): v._value = d + for v, d in zip(dyn_vars, new_dyn_vs): v._value = d + + # check returned grads + if len(grad_vars) == 0: + grads = grads[1] if isinstance(argnums, int) else grads[1:] + else: + var_grads = grad_tree.unflatten(grads[0]) + if argnums is None: + grads = var_grads + else: + arg_grads = grads[1] if isinstance(argnums, int) else grads[1:] + grads = (var_grads, arg_grads) + + # check returned value + if return_value: + # check aux + if has_aux: + return grads, outputs[0], outputs[1] + else: + return grads, outputs + else: + # check aux + if has_aux: + return grads, outputs[1] + else: + return grads + + return call_func + + +def _check_vars(variables): + if variables is None: + vars, tree = tree_flatten(variables, is_leaf=lambda a: isinstance(a, Array)) + return vars, tree + if isinstance(variables, dict): + variables = dict(variables) + elif isinstance(variables, (list, tuple)): + variables = tuple(variables) + elif isinstance(variables, Array): + pass + else: + raise ValueError + vars, tree = tree_flatten(variables, is_leaf=lambda a: isinstance(a, Array)) + for v in vars: + if not isinstance(v, Array): + raise ValueError(f'"dyn_vars" and "grad_vars" only supports dict ' + f'of Array, but got {type(v)}: {v}') + return vars, tree + + +def _grad_checking(func: Callable, + dyn_vars: Union[Dict, Sequence], + grad_vars: Union[Dict, Sequence]): + # check function + if not callable(func): + raise ValueError(f'Must be a callable object. 
But we got {func}') + + # check "vars", make sure it is an instance of ArrayCollector + dyn_vars, _ = _check_vars(dyn_vars) + grad_vars, grad_tree = _check_vars(grad_vars) + + # check the duplicate in "dyn_vars" and "grad_vars" + dyn_vars = tuple(ArrayCollector.from_other(dyn_vars).unique().values()) + new_dyn_vars = [] + _dyn_var_ids = set([id(v) for v in grad_vars]) + for v in dyn_vars: + if id(v) not in _dyn_var_ids: + new_dyn_vars.append(v) + _dyn_var_ids.add(id(v)) + return tuple(new_dyn_vars), grad_vars, grad_tree + + +def _cls_grad(func, grad_vars, dyn_vars, argnums, has_aux=False, + holomorphic=False, allow_int=False, reduce_axes=()): + # parameters + assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(argnums, (tuple, list)) # tuple/list of int + + # gradient functions + if has_aux: + @partial(jax.grad, argnums=argnums, has_aux=True, holomorphic=holomorphic, + allow_int=allow_int, reduce_axes=reduce_axes) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + # Users should return the auxiliary data like:: + # >>> # 1. example of return one data + # >>> return scalar_loss, data + # >>> # 2. example of return multiple data + # >>> return scalar_loss, (data1, data2, ...) + outputs = func(*args, **kwargs) + # outputs: [0] is the value for gradient, + # [1] is other values for return + output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + else: + @partial(jax.grad, + argnums=argnums, has_aux=True, holomorphic=holomorphic, + allow_int=allow_int, reduce_axes=reduce_axes) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + # Users should return the scalar value like this:: + # >>> return scalar_loss + output = func(*args, **kwargs) + output2 = output.value if isinstance(output, Array) else output + return output2, (output, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + return grad_func + + +def grad( + func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, + allow_int=False, reduce_axes=(), has_aux=None, return_value=False, + auto_infer=True +) -> ObjectTransform: + """Automatic gradient computation for functions or class objects. + + This gradient function only support scalar return. It creates a function + which evaluates the gradient of ``func``. + + It's worthy to note that the returns are different for different argument settings (where ``arg_grads`` refers + to the gradients of "argnums", and ``var_grads`` refers to the gradients of "grad_vars"). + + 1. When "grad_vars" is None + - "has_aux=False" + "return_value=False" => ``arg_grads``. + - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``. + 2. When "grad_vars" is not None and "argnums" is None + - "has_aux=False" + "return_value=False" => ``var_grads``. + - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``. + 3. 
When "grad_vars" is not None and "argnums" is not None + - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``. + - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``. + - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. + - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. + + Let's see some examples below. + + Before start, let's figure out what should be provided as ``grad_vars``? + And, what should be labeled in ``argnums``? + Take the following codes as example: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> class Example(bp.BrainPyObject): + >>> def __init__(self): + >>> super(Example, self).__init__() + >>> self.x = bm.TrainVar(bm.zeros(1)) + >>> self.y = bm.random.rand(10) + >>> def __call__(self, z, v): + >>> t1 = self.x * self.y.sum() + >>> t2 = bm.tanh(z * v + t1) + >>> return t2.mean() + >>> + >>> # This code is equivalent to the following function: + >>> + >>> x = bm.TrainVar(bm.zeros(1)) + >>> y = bm.random.rand(10) + >>> def f(z, v): + >>> t1 = x * y.sum() + >>> t2 = bm.tanh(z * v + t1) + >>> return t2.mean() + + Generally speaking, all gradient variables which not provided in arguments should be + labeled as ``grad_vars``, while all gradient variables provided in the function arguments + should be declared in ``argnums``. + In above codes, we try to take gradients of ``self.x`` and arguments ``z`` and ``v``, we should + call ``brainpy.math.grad`` as: + + >>> f = Example() + >>> f_grad = bm.grad(f, grad_vars=f.x, argnums=(0, 1)) + + + Examples + -------- + + Grad for a pure function: + + >>> import brainpy as bp + >>> grad_tanh = grad(bp.math.tanh) + >>> print(grad_tanh(0.2)) + 0.961043 + + Parameters + ---------- + func : callable, function, BrainPyObject + Function to be differentiated. Its arguments at positions specified by + ``argnums`` should be arrays, scalars, or standard Python containers. + Argument arrays in the positions specified by ``argnums`` must be of + inexact (i.e., floating-point or complex) type. It should return a scalar + (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) + dyn_vars : optional, ArrayType, sequence of ArrayType, dict + The dynamically changed variables used in ``func``. + grad_vars : optional, ArrayType, sequence of ArrayType, dict + The variables in ``func`` to take their gradients. + argnums : optional, integer or sequence of integers + Specifies which positional argument(s) to differentiate with respect to (default 0). + has_aux: optional, bool + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + return_value : bool + Whether return the loss value. + holomorphic: optional, bool + Indicates whether ``fun`` is promised to be + holomorphic. If True, inputs and outputs must be complex. Default False. + allow_int: optional, bool + Whether to allow differentiating with + respect to integer valued inputs. The gradient of an integer input will + have a trivial vector-space dtype (float0). Default False. + reduce_axes: optional, tuple of int + tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + gradient will be per-example over named axes. 
For example, if ``'batch'`` + is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a + function that computes the total gradient while ``grad(f)`` will create + one that computes the per-example gradient. + auto_infer: bool + Automatically infer all ``Variable`` instances used in the target. + + + Returns + ------- + func : ObjectTransform + A function with the same arguments as ``fun``, that evaluates the gradient + of ``fun``. If ``argnums`` is an integer then the gradient has the same + shape and type as the positional argument indicated by that integer. If + argnums is a tuple of integers, the gradient is a tuple of values with the + same shapes and types as the corresponding arguments. If ``has_aux`` is True + then a pair of (gradient, auxiliary_data) is returned. + """ + + if dyn_vars is None: + dyn_vars = infer_dyn_vars(func) if auto_infer else dict() + + dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) + # dyn_vars -> ArrayCollector + # grad_vars -> ArrayCollector + has_aux = False if has_aux is None else has_aux + + # gradient + if len(dyn_vars) == 0 and len(grad_vars) == 0: + argnums = 0 if argnums is None else argnums + if return_value: + grad_func = jax.value_and_grad(fun=func, + argnums=argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes) + + def call_func(*args, **kwargs): + result = grad_func(*args, **kwargs) + if has_aux: + (ans, aux), g = result + return g, ans, aux + else: + ans, g = result + return g, ans + + return GradientFunTransform(call_func, dyn_vars=dyn_vars, grad_vars=grad_vars, origin_fun=func) + + else: + # has_aux = True: g, aux + # has_aux = False: g + call_func = jax.grad(fun=func, + argnums=argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes) + return GradientFunTransform(call_func, dyn_vars=dyn_vars, grad_vars=grad_vars, origin_fun=func) + + else: + # argnums + _argnums, _ = tree_flatten(argnums) + _argnums = tuple(a + 2 for a in _argnums) + if argnums is None and len(grad_vars) == 0: + raise errors.MathError('We detect no require to compute gradients because ' + '"grad_vars" is None and "argnums" is also None. 
' + 'Please provide one of them.') + # computation + grad_func = _cls_grad(func=func, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=(0,) + _argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes) + + return GradientTransform(grad_func=grad_func, + grad_tree=grad_tree, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=argnums, + return_value=return_value, + has_aux=has_aux, + origin_fun=func) + + +def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None): + leaves, treedef = tree_flatten(pytree, is_leaf=is_leaf) + axis = axis % arr.ndim + shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1:] for l in leaves] + parts = arr.split(np.cumsum(safe_map(np.size, leaves[:-1])), axis) + reshaped_parts = [x.reshape(shape) for x, shape in zip(parts, shapes)] + return tree_unflatten(treedef, reshaped_parts, ) + + +def _std_basis(pytree): + leaves, _ = tree_flatten(pytree) + ndim = sum(safe_map(np.size, leaves)) + dtype = dtypes.result_type(*leaves) + flat_basis = jax.numpy.eye(ndim, dtype=dtype) + return _unravel_array_into_pytree(pytree, 1, flat_basis) + + +_isleaf = lambda x: isinstance(x, Array) + + +def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False): + _check_callable(fun) + + def jacfun(*args, **kwargs): + f = linear_util.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) + tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args) + if has_aux: + y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True) + else: + y, pullback = _vjp(f_partial, *dyn_args, has_aux=False) + tree_map(partial(_check_output_dtype_jacrev, holomorphic), y) + jac = vmap(pullback)(_std_basis(y)) + jac = jac[0] if isinstance(argnums, int) else jac + example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args + jac_tree = tree_map(partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf), jac, is_leaf=_isleaf) + jac = tree_transpose(tree_structure(example_args), tree_flatten(y, is_leaf=_isleaf)[1], jac_tree) + if return_value: + return (jac, y, aux) if has_aux else (jac, y) + else: + return (jac, aux) if has_aux else jac + + return GradientFunTransform(jacfun, dyn_vars=(), grad_vars=(), origin_fun=fun) + + +def _cls_jacrev(func, grad_vars, dyn_vars, argnums, + holomorphic=False, allow_int=False, has_aux=False): + # parameters + assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(argnums, (tuple, list)) # tuple/list of int + + # final functions + if has_aux: + @partial(_jacrev, argnums=argnums, holomorphic=holomorphic, + allow_int=allow_int, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + # outputs: [0] is the value for gradient, + # [1] is other values for return + outputs = func(*args, **kwargs) + output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + else: + @partial(_jacrev, argnums=argnums, holomorphic=holomorphic, + allow_int=allow_int, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + outputs = func(*args, **kwargs) + 
output = outputs.value if isinstance(outputs, Array) else outputs + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + return grad_func + + +def jacrev( + func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, + allow_int=False, has_aux=None, return_value=False, + auto_infer=True +) -> ObjectTransform: + """Extending automatic Jacobian (reverse-mode) of ``func`` to classes. + + This function extends the JAX official ``jacrev`` to make automatic jacobian + computation on functions and class functions. Moreover, it supports returning + value ("return_value") and returning auxiliary data ("has_aux"). + + Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are + different for different argument settings in ``brainpy.math.jacrev``. + + 1. When "grad_vars" is None + - "has_aux=False" + "return_value=False" => ``arg_grads``. + - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``. + 2. When "grad_vars" is not None and "argnums" is None + - "has_aux=False" + "return_value=False" => ``var_grads``. + - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``. + 3. When "grad_vars" is not None and "argnums" is not None + - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``. + - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``. + - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. + - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. + + + Parameters + ---------- + func: Function whose Jacobian is to be computed. + dyn_vars : optional, ArrayType, sequence of ArrayType, dict + The dynamically changed variables used in ``func``. + grad_vars : optional, ArrayType, sequence of ArrayType, dict + The variables in ``func`` to take their gradients. + has_aux: optional, bool + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + return_value : bool + Whether return the loss value. + argnums: Optional, integer or sequence of integers. Specifies which + positional argument(s) to differentiate with respect to (default ``0``). + holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic. Default False. + allow_int: Optional, bool. Whether to allow differentiating with + respect to integer valued inputs. The gradient of an integer input will + have a trivial vector-space dtype (float0). Default False. + auto_infer: bool + Automatically infer all ``Variable`` instance. + + Returns + ------- + fun: ObjectTransform + The transformed object. 
+ """ + if dyn_vars is None: + dyn_vars = infer_dyn_vars(func) if auto_infer else dict() + + dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) + has_aux = False if has_aux is None else has_aux + + if (len(dyn_vars) == 0) and (len(grad_vars) == 0): + argnums = 0 if argnums is None else argnums + return _jacrev(fun=func, + argnums=argnums, + holomorphic=holomorphic, + allow_int=allow_int, + has_aux=has_aux, + return_value=return_value) + else: + _argnums, _ = tree_flatten(argnums) + _argnums = tuple(a + 2 for a in _argnums) + if argnums is None and len(grad_vars) == 0: + raise errors.MathError('We detect no require to compute gradients because ' + '"grad_vars" is None and "argnums" is also None. ' + 'Please provide one of them.') + # computation + grad_func = _cls_jacrev(func=func, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=(0,) + _argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int) + + return GradientTransform(grad_func=grad_func, + grad_tree=grad_tree, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=argnums, + return_value=return_value, + has_aux=has_aux, + origin_fun=func) + + +jacobian = jacrev + + +def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False): + _check_callable(fun) + if has_aux and jax.__version__ < '0.2.28': + raise NotImplementedError(f'"has_aux" only supported in jax>=0.2.28, but we detect ' + f'the current jax version is {jax.__version__}') + + def jacfun(*args, **kwargs): + f = linear_util.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) + tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args) + if has_aux: + pushfwd = partial(_jvp, f_partial, dyn_args, has_aux=True) + y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args)) + else: + pushfwd = partial(_jvp, f_partial, dyn_args) + y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args)) + tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y) + example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args + jac = tree_map(partial(_unravel_array_into_pytree, example_args, -1, is_leaf=_isleaf), jac, is_leaf=_isleaf) + if return_value: + return (jac, y, aux) if has_aux else (jac, y) + else: + return (jac, aux) if has_aux else jac + + return GradientFunTransform(jacfun, dyn_vars=(), grad_vars=(), origin_fun=fun) + + +def _cls_jacfwd(func, grad_vars, dyn_vars, argnums, holomorphic=False, has_aux=False): + # parameters + assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(argnums, (tuple, list)) # tuple/list of int + + # final functions + if has_aux: + @partial(_jacfwd, + argnums=argnums, holomorphic=holomorphic, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + # outputs: [0] is the value for gradient, + # [1] is other values for return + outputs = func(*args, **kwargs) + output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + else: + @partial(_jacfwd, + argnums=argnums, holomorphic=holomorphic, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): 
v._value = d + outputs = func(*args, **kwargs) + output = outputs.value if isinstance(outputs, Array) else outputs + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + return grad_func + + +def jacfwd( + func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, + has_aux=None, return_value=False, auto_infer=True +) -> ObjectTransform: + """Extending automatic Jacobian (forward-mode) of ``func`` to classes. + + This function extends the JAX official ``jacfwd`` to make automatic jacobian + computation on functions and class functions. Moreover, it supports returning + value ("return_value") and returning auxiliary data ("has_aux"). + + Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are + different for different argument settings in ``brainpy.math.jacfwd``. + + 1. When "grad_vars" is None + - "has_aux=False" + "return_value=False" => ``arg_grads``. + - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``. + 2. When "grad_vars" is not None and "argnums" is None + - "has_aux=False" + "return_value=False" => ``var_grads``. + - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``. + 3. When "grad_vars" is not None and "argnums" is not None + - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``. + - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``. + - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. + - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. + + Parameters + ---------- + func: Function whose Jacobian is to be computed. + dyn_vars : optional, ArrayType, sequence of ArrayType, dict + The dynamically changed variables used in ``func``. + grad_vars : optional, ArrayType, sequence of ArrayType, dict + The variables in ``func`` to take their gradients. + has_aux: optional, bool + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + return_value : bool + Whether return the loss value. + argnums: Optional, integer or sequence of integers. Specifies which + positional argument(s) to differentiate with respect to (default ``0``). + holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic. Default False. + auto_infer: bool + Automatically infer all ``Variable`` instance. + + Returns + ------- + obj: ObjectTransform + The transformed object. 
+ """ + if dyn_vars is None: + dyn_vars = infer_dyn_vars(func) if auto_infer else dict() + + dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) + has_aux = False if has_aux is None else has_aux + + if (len(dyn_vars) == 0) and (len(grad_vars) == 0): + argnums = 0 if argnums is None else argnums + return _jacfwd(fun=func, + argnums=argnums, + holomorphic=holomorphic, + has_aux=has_aux, + return_value=return_value) + else: + _argnums, _ = tree_flatten(argnums) + _argnums = tuple(a + 2 for a in _argnums) + if argnums is None and len(grad_vars) == 0: + raise errors.MathError('We detect no require to compute gradients because ' + '"grad_vars" is None and "argnums" is also None. ' + 'Please provide one of them.') + # computation + grad_func = _cls_jacfwd(func=func, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=(0,) + _argnums, + has_aux=has_aux, + holomorphic=holomorphic) + + return GradientTransform(grad_func=grad_func, + grad_tree=grad_tree, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=argnums, + return_value=return_value, + has_aux=has_aux, + origin_fun=func) + + +def hessian( + func, grad_vars=None, dyn_vars=None, argnums=None, + holomorphic=False, return_value=False, auto_infer=True +) -> ObjectTransform: + """Hessian of ``func`` as a dense array. + + Parameters + ---------- + func : callable, function + Function whose Hessian is to be computed. Its arguments at positions + specified by ``argnums`` should be arrays, scalars, or standard Python + containers thereof. It should return arrays, scalars, or standard Python + containers thereof. + dyn_vars : optional, ArrayCollector, sequence of ArrayType + The dynamical changed variables. + grad_vars : optional, ArrayCollector, sequence of ArrayType + The variables required to compute their gradients. + argnums: Optional, integer or sequence of integers + Specifies which positional argument(s) to differentiate with respect to (default ``0``). + holomorphic : bool + Indicates whether ``fun`` is promised to be holomorphic. Default False. + return_value : bool + Whether return the hessian values. + auto_infer: bool + Automatically infer all ``Variable`` instance. + + Returns + ------- + obj: ObjectTransform + The transformed object. 
+ """ + if dyn_vars is None: + dyn_vars = infer_dyn_vars(func) if auto_infer else dict() + + dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) + argnums = 0 if argnums is None else argnums + + if (len(dyn_vars) == 0) and (len(grad_vars) == 0) and (not return_value): + f = jax.hessian(func, argnums=argnums, holomorphic=holomorphic) + return GradientFunTransform(f, dyn_vars=(), grad_vars=(), origin_fun=func) + else: + return jacfwd(jacrev(func, + dyn_vars=dyn_vars, + grad_vars=grad_vars, + argnums=argnums, + holomorphic=holomorphic), + dyn_vars=dyn_vars, + grad_vars=grad_vars, + argnums=argnums, + holomorphic=holomorphic, + return_value=return_value) + + +def _vector_grad(func, argnums=0, return_value=False, has_aux=False): + _check_callable(func) + + def grad_fun(*args, **kwargs): + f = linear_util.wrap_init(func, kwargs) + f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) + if has_aux: + y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True) + else: + y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False) + leaves, tree = tree_flatten(y) + tangents = tree_unflatten(tree, [jnp.ones_like(l) for l in leaves]) + grads = vjp_fn(tangents) + if isinstance(argnums, int): + grads = grads[0] + if has_aux: + return (grads, y, aux) if return_value else (grads, aux) + else: + return (grads, y) if return_value else grads + + return GradientFunTransform(grad_fun, (), (), origin_fun=func) + + +def _cls_vector_grad(func, grad_vars, dyn_vars, argnums, has_aux=False): + # parameters + assert isinstance(dyn_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(grad_vars, (tuple, list)) # tuple/list of Variable + assert isinstance(argnums, (tuple, list)) # tuple/list of int + + # final functions + if has_aux: + @partial(_vector_grad, argnums=argnums, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + outputs = func(*args, **kwargs) + output = outputs[0].value if isinstance(outputs[0], Array) else outputs[0] + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + else: + @partial(_vector_grad, argnums=argnums, has_aux=True) + def grad_func(grad_values, dyn_values, *args, **kwargs): + for v, d in zip(dyn_vars, dyn_values): v._value = d + for v, d in zip(grad_vars, grad_values): v._value = d + outputs = func(*args, **kwargs) + output = outputs.value if isinstance(outputs, Array) else outputs + return output, (outputs, [v.value for v in grad_vars], [v.value for v in dyn_vars]) + + return grad_func + + +def vector_grad( + func, dyn_vars=None, grad_vars=None, argnums=None, + return_value=False, has_aux=None, auto_infer=True +) -> ObjectTransform: + """Take vector-valued gradients for function ``func``. + + Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, + `brainpy.math.jacrev <./brainpy.math.autograd.jacrev.html>`_ and + `brainpy.math.jacfwd <./brainpy.math.autograd.jacfwd.html>`_, + the returns in this function are different for different argument settings. + + 1. When "grad_vars" is None + - "has_aux=False" + "return_value=False" => ``arg_grads``. + - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``. + 2. 
When "grad_vars" is not None and "argnums" is None + - "has_aux=False" + "return_value=False" => ``var_grads``. + - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``. + - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``. + - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``. + 3. When "grad_vars" is not None and "argnums" is not None + - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``. + - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``. + - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. + - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. + + + Parameters + ---------- + func: Function whose Jacobian is to be computed. + dyn_vars : optional, ArrayType, sequence of ArrayType, dict + The dynamically changed variables used in ``func``. + grad_vars : optional, ArrayType, sequence of ArrayType, dict + The variables in ``func`` to take their gradients. + has_aux: optional, bool + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + return_value : bool + Whether return the loss value. + argnums: Optional, integer or sequence of integers. Specifies which + positional argument(s) to differentiate with respect to (default ``0``). + auto_infer: bool + Automatically infer all ``Variable`` instance. + + Returns + ------- + func : ObjectTransform + The vector gradient function. + """ + if dyn_vars is None: + dyn_vars = infer_dyn_vars(func) if auto_infer else dict() + + dyn_vars, grad_vars, grad_tree = _grad_checking(func, dyn_vars, grad_vars) + has_aux = False if has_aux is None else has_aux + + if (len(dyn_vars) == 0) and (len(grad_vars) == 0): + argnums = 0 if argnums is None else argnums + return _vector_grad(func=func, + argnums=argnums, + return_value=return_value, + has_aux=has_aux) + + else: + _argnums, _ = tree_flatten(argnums) + _argnums = tuple(a + 2 for a in _argnums) + if argnums is None and len(grad_vars) == 0: + raise errors.MathError('We detect no require to compute gradients because ' + '"grad_vars" is None and "argnums" is also None. ' + 'Please provide one of them.') + # computation + grad_func = _cls_vector_grad(func=func, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=(0,) + _argnums, + has_aux=has_aux) + + return GradientTransform(grad_func=grad_func, + grad_tree=grad_tree, + grad_vars=grad_vars, + dyn_vars=dyn_vars, + argnums=argnums, + return_value=return_value, + has_aux=has_aux, + origin_fun=func)