diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index ad8a5ccf..2e5e103c 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -679,16 +679,213 @@ def jacfwd(
                            transform_setting=dict(holomorphic=holomorphic))
 
 
+def _functional_hessian(
+    fun: Callable,
+    argnums: Optional[Union[int, Sequence[int]]] = None,
+    has_aux: bool = False,
+    holomorphic: bool = False,
+):
+  return _jacfwd(
+    _jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
+    argnums, has_aux=has_aux, holomorphic=holomorphic
+  )
+
+
+class GradientTransformPreserveTree(ObjectTransform):
+  """
+  Object-oriented Automatic Differentiation Transformation in BrainPy.
+  """
+
+  def __init__(
+      self,
+      target: Callable,
+      transform: Callable,
+
+      # variables and nodes
+      grad_vars: Dict[str, Variable],
+
+      # gradient setting
+      argnums: Optional[Union[int, Sequence[int]]],
+      return_value: bool,
+      has_aux: bool,
+      transform_setting: Optional[Dict[str, Any]] = None,
+
+      # other
+      name: str = None,
+  ):
+    super().__init__(name=name)
+
+    # gradient variables
+    if grad_vars is None:
+      grad_vars = dict()
+    assert isinstance(grad_vars, dict), 'grad_vars should be a dict'
+    new_grad_vars = {}
+    for k, v in grad_vars.items():
+      assert isinstance(v, Variable)
+      new_grad_vars[k] = v
+    self._grad_vars = new_grad_vars
+
+    # parameters
+    if argnums is None and len(self._grad_vars) == 0:
+      argnums = 0
+    if argnums is None:
+      assert len(self._grad_vars) > 0
+      _argnums = 0
+    elif isinstance(argnums, int):
+      _argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
+    else:
+      _argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
+      _argnums = tuple(a + 2 for a in _argnums)
+      if len(self._grad_vars) > 0:
+        _argnums = (0,) + _argnums
+    self._nonvar_argnums = argnums
+    self._argnums = _argnums
+    self._return_value = return_value
+    self._has_aux = has_aux
+
+    # target
+    self.target = target
+
+    # transform
+    self._eval_dyn_vars = False
+    self._grad_transform = transform
+    self._dyn_vars = VariableStack()
+    self._transform = None
+    self._grad_setting = dict() if transform_setting is None else transform_setting
+    if self._has_aux:
+      self._transform = self._grad_transform(
+        self._f_grad_with_aux_to_transform,
+        argnums=self._argnums,
+        has_aux=True,
+        **self._grad_setting
+      )
+    else:
+      self._transform = self._grad_transform(
+        self._f_grad_without_aux_to_transform,
+        argnums=self._argnums,
+        has_aux=True,
+        **self._grad_setting
+      )
+
+  def _f_grad_with_aux_to_transform(self,
+                                    grad_values: dict,
+                                    dyn_values: dict,
+                                    *args,
+                                    **kwargs):
+    for k in dyn_values.keys():
+      self._dyn_vars[k]._value = dyn_values[k]
+    for k, v in grad_values.items():
+      self._grad_vars[k]._value = v
+    # Users should return the auxiliary data like::
+    # >>> # 1. example of return one data
+    # >>> return scalar_loss, data
+    # >>> # 2. example of return multiple data
+    # >>> return scalar_loss, (data1, data2, ...)
+    outputs = self.target(*args, **kwargs)
+    # outputs: [0] is the value for gradient,
+    #          [1] is other values for return
+    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
+    return output0, (outputs, {k: v for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())
+
+  def _f_grad_without_aux_to_transform(self,
+                                       grad_values: dict,
+                                       dyn_values: dict,
+                                       *args,
+                                       **kwargs):
+    for k in dyn_values.keys():
+      self._dyn_vars[k].value = dyn_values[k]
+    for k, v in grad_values.items():
+      self._grad_vars[k].value = v
+    # Users should return the scalar value like this::
+    # >>> return scalar_loss
+    output = self.target(*args, **kwargs)
+    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
+    return output0, (output, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())
+
+  def __repr__(self):
+    name = self.__class__.__name__
+    f = tools.repr_object(self.target)
+    f = tools.repr_context(f, " " * (len(name) + 6))
+    format_ref = (f'{name}({self.name}, target={f}, \n' +
+                  f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
+                  f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
+    return format_ref
+
+  def _return(self, rets):
+    grads, (outputs, new_grad_vs, new_dyn_vs) = rets
+    for k, v in new_grad_vs.items():
+      self._grad_vars[k].value = v
+    for k in new_dyn_vs.keys():
+      self._dyn_vars[k].value = new_dyn_vs[k]
+
+    # check returned grads
+    if len(self._grad_vars) > 0:
+      if self._nonvar_argnums is None:
+        pass
+      else:
+        arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
+        grads = (grads[0], arg_grads)
+
+    # check returned value
+    if self._return_value:
+      # check aux
+      if self._has_aux:
+        return grads, outputs[0], outputs[1]
+      else:
+        return grads, outputs
+    else:
+      # check aux
+      if self._has_aux:
+        return grads, outputs[1]
+      else:
+        return grads
+
+  def __call__(self, *args, **kwargs):
+    if jax.config.jax_disable_jit:  # disable JIT
+      rets = self._transform(
+        {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
+        self._dyn_vars.dict_data(),  # dynamical variables
+        *args,
+        **kwargs
+      )
+      return self._return(rets)
+
+    elif not self._eval_dyn_vars:  # evaluate dynamical variables
+      stack = get_stack_cache(self.target)
+      if stack is None:
+        with VariableStack() as stack:
+          rets = eval_shape(
+            self._transform,
+            {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
+            {},  # dynamical variables
+            *args,
+            **kwargs
+          )
+        cache_stack(self.target, stack)
+
+        self._dyn_vars = stack
+        self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars.values()])
+        self._eval_dyn_vars = True
+
+        # if not the outermost transformation
+        if not stack.is_first_stack():
+          return self._return(rets)
+
+    rets = self._transform(
+      {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
+      self._dyn_vars.dict_data(),  # dynamical variables
+      *args,
+      **kwargs
+    )
+    return self._return(rets)
+
+
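# --- Editor's note (not part of the patch) ---------------------------------
# A minimal standalone sketch of the convention GradientTransformPreserveTree
# builds on: the wrapped function takes the grad-variable dict as positional
# argument 0, so calling jax.hessian(..., argnums=0, has_aux=True) yields a
# nested dict-of-dicts of second derivatives plus the auxiliary payload that
# carries variable state back out. The function `f`, the key names, and the
# shapes below are illustrative assumptions only.
import jax
import jax.numpy as jnp

def f(params, x):
  loss = jnp.sum(params['w'] * x) ** 2 + jnp.sum(params['b'] ** 2)
  return loss, {'loss': loss}  # (value to differentiate, auxiliary data)

params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
x = jnp.arange(3.0)
hess, aux = jax.hessian(f, argnums=0, has_aux=True)(params, x)
# hess mirrors the dict structure twice: hess['w']['w'].shape == (3, 3),
# hess['w']['b'].shape == (3, 2), and so on.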
 def hessian(
     func: Callable,
     grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
-    return_value: bool = False,
+    has_aux: Optional[bool] = None,
     holomorphic=False,
-
-    # deprecated
-    dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
-    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
 ) -> ObjectTransform:
   """Hessian of ``func`` as a dense array.
@@ -705,42 +902,24 @@ def hessian(
     Specifies which positional argument(s) to differentiate with respect to (default ``0``).
   holomorphic : bool
     Indicates whether ``fun`` is promised to be holomorphic. Default False.
-  return_value : bool
-    Whether return the hessian values.
-  dyn_vars : optional, ArrayType, sequence of ArrayType, dict
-    The dynamically changed variables used in ``func``.
-
-    .. deprecated:: 2.4.0
-       No longer need to provide ``dyn_vars``. This function is capable of automatically
-       collecting the dynamical variables used in the target ``func``.
-  child_objs: optional, BrainPyObject, sequnce, dict
-
-    .. versionadded:: 2.3.1
-
-    .. deprecated:: 2.4.0
-       No longer need to provide ``child_objs``. This function is capable of automatically
-       collecting the children objects used in the target ``func``.
+  has_aux : bool, optional
+    Indicates whether ``fun`` returns a pair where the first element is
+    considered the output of the mathematical function to be differentiated
+    and the second element is auxiliary data. Default False.
 
   Returns
   -------
   obj: ObjectTransform
     The transformed object.
   """
-  child_objs = check.is_all_objs(child_objs, out_as='dict')
-  dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
-
-  return jacfwd(jacrev(func,
-                       dyn_vars=dyn_vars,
-                       child_objs=child_objs,
-                       grad_vars=grad_vars,
-                       argnums=argnums,
-                       holomorphic=holomorphic),
-                dyn_vars=dyn_vars,
-                child_objs=child_objs,
-                grad_vars=grad_vars,
-                argnums=argnums,
-                holomorphic=holomorphic,
-                return_value=return_value)
+  return GradientTransformPreserveTree(target=func,
+                                       transform=jax.hessian,
+                                       grad_vars=grad_vars,
+                                       argnums=argnums,
+                                       has_aux=False if has_aux is None else has_aux,
+                                       transform_setting=dict(holomorphic=holomorphic),
+                                       return_value=False)
 
 
 def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
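# --- Editor's note (not part of the patch) ---------------------------------
# A small usage sketch of the reworked `hessian` transform with a dict of
# Variables as `grad_vars`; the returned Hessian preserves the dict (pytree)
# structure, which is what the TestHessian case below cross-checks against
# jax.hessian. Variable names and shapes here are illustrative assumptions.
import brainpy.math as bm

a = bm.Variable(bm.ones(3))
b = bm.Variable(bm.full(3, 2.))

def loss():
  # scalar function of the two Variables
  return bm.sum(a * a * b)

hess_fn = bm.hessian(loss, grad_vars={'a': a, 'b': b})
h = hess_fn()
# h['a']['a'] is the (3, 3) block d2L/da2, h['a']['b'] the (3, 3) cross block,
# mirroring the dict structure of grad_vars.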
diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py
index b4fefc05..90829d80 100644
--- a/brainpy/_src/math/object_transform/tests/test_autograd.py
+++ b/brainpy/_src/math/object_transform/tests/test_autograd.py
@@ -1171,3 +1171,53 @@ def f(a, b):
     self.assertTrue(file.read().strip() == expect_res.strip())
 
 
+
+class TestHessian(unittest.TestCase):
+  def test_hessian5(self):
+    bm.set_mode(bm.training_mode)
+
+    class RNN(bp.DynamicalSystem):
+      def __init__(self, num_in, num_hidden):
+        super(RNN, self).__init__()
+        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
+        self.out = bp.dnn.Dense(num_hidden, 1)
+
+      def update(self, x):
+        return self.out(self.rnn(x))
+
+    # define the loss function
+    def lossfunc(inputs, targets):
+      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+      predicts = runner.predict(inputs)
+      loss = bp.losses.mean_squared_error(predicts, targets)
+      return loss
+
+    model = RNN(1, 2)
+    data_x = bm.random.rand(1, 1000, 1)
+    data_y = data_x + bm.random.randn(1, 1000, 1)
+
+    bp.reset_state(model, 1)
+    losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
+    hess_matrix = losshess(data_x, data_y)
+
+    weights = model.train_vars().unique()
+
+    # define the loss function
+    def loss_func_for_jax(weight_vals, inputs, targets):
+      for k, v in weight_vals.items():
+        weights[k].value = v
+      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+      predicts = runner.predict(inputs)
+      loss = bp.losses.mean_squared_error(predicts, targets)
+      return loss
+
+    bp.reset_state(model, 1)
+    jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
+
+    for k, v in hess_matrix.items():
+      for kk, vv in v.items():
+        self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
+
+    bm.clear_buffer_memory()
+
+
diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py
index ebad7eb0..4e1923e9 100644
--- a/brainpy/_src/math/object_transform/tests/test_base.py
+++ b/brainpy/_src/math/object_transform/tests/test_base.py
@@ -237,12 +237,12 @@ def test1(self):
     hh = bp.dyn.HH(1)
     hh.reset()
 
-    tree = jax.tree_structure(hh)
-    leaves = jax.tree_leaves(hh)
+    tree = jax.tree.structure(hh)
+    leaves = jax.tree.leaves(hh)
 
     print(tree)
     print(leaves)
-    print(jax.tree_unflatten(tree, leaves))
+    print(jax.tree.unflatten(tree, leaves))
     print()
 
 
@@ -281,13 +281,13 @@ def not_close(x, y):
     def all_close(x, y):
       assert bm.allclose(x, y)
 
-    jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+    jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
 
-    random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
-    jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+    random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+    jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
 
     obj.load_state_dict(random_state)
-    jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+    jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 98398ae2..641f99fd 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -8,6 +8,8 @@ tqdm
 pathos
 taichi
 numba
+braincore
+braintools
 
 # test requirements
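# --- Editor's note (not part of the patch) ---------------------------------
# The test_base.py changes above migrate from the deprecated top-level
# jax.tree_* utilities to the jax.tree namespace available in recent JAX
# releases. A minimal equivalence sketch; the pytree below is illustrative.
import jax

data = {'w': [1.0, 2.0], 'b': 3.0}
leaves, treedef = jax.tree.flatten(data)        # was: jax.tree_flatten(data)
doubled = jax.tree.map(lambda x: x * 2, data)   # was: jax.tree_map(...)
restored = jax.tree.unflatten(treedef, leaves)  # was: jax.tree_unflatten(...)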