Skip to content

Commit

Permalink
fix issue #661 (#662)
Browse files Browse the repository at this point in the history
* fix issue #661

* fix tests

* updates
  • Loading branch information
chaoming0625 committed Apr 14, 2024
1 parent b06d80a commit 4bd1898
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 41 deletions.
247 changes: 213 additions & 34 deletions brainpy/_src/math/object_transform/autograd.py
Expand Up @@ -679,16 +679,213 @@ def jacfwd(
transform_setting=dict(holomorphic=holomorphic))


def _functional_hessian(
fun: Callable,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: bool = False,
holomorphic: bool = False,
):
return _jacfwd(
_jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
argnums, has_aux=has_aux, holomorphic=holomorphic
)


class GradientTransformPreserveTree(ObjectTransform):
"""
Object-oriented Automatic Differentiation Transformation in BrainPy.
"""

def __init__(
self,
target: Callable,
transform: Callable,

# variables and nodes
grad_vars: Dict[str, Variable],

# gradient setting
argnums: Optional[Union[int, Sequence[int]]],
return_value: bool,
has_aux: bool,
transform_setting: Optional[Dict[str, Any]] = None,

# other
name: str = None,
):
super().__init__(name=name)

# gradient variables
if grad_vars is None:
grad_vars = dict()
assert isinstance(grad_vars, dict), 'grad_vars should be a dict'
new_grad_vars = {}
for k, v in grad_vars.items():
assert isinstance(v, Variable)
new_grad_vars[k] = v
self._grad_vars = new_grad_vars

# parameters
if argnums is None and len(self._grad_vars) == 0:
argnums = 0
if argnums is None:
assert len(self._grad_vars) > 0
_argnums = 0
elif isinstance(argnums, int):
_argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
else:
_argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
_argnums = tuple(a + 2 for a in _argnums)
if len(self._grad_vars) > 0:
_argnums = (0,) + _argnums
self._nonvar_argnums = argnums
self._argnums = _argnums
self._return_value = return_value
self._has_aux = has_aux

# target
self.target = target

# transform
self._eval_dyn_vars = False
self._grad_transform = transform
self._dyn_vars = VariableStack()
self._transform = None
self._grad_setting = dict() if transform_setting is None else transform_setting
if self._has_aux:
self._transform = self._grad_transform(
self._f_grad_with_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)
else:
self._transform = self._grad_transform(
self._f_grad_without_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)

def _f_grad_with_aux_to_transform(self,
grad_values: dict,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k]._value = dyn_values[k]
for k, v in grad_values.items():
self._grad_vars[k]._value = v
# Users should return the auxiliary data like::
# >>> # 1. example of return one data
# >>> return scalar_loss, data
# >>> # 2. example of return multiple data
# >>> return scalar_loss, (data1, data2, ...)
outputs = self.target(*args, **kwargs)
# outputs: [0] is the value for gradient,
# [1] is other values for return
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
return output0, (outputs, {k: v for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

def _f_grad_without_aux_to_transform(self,
grad_values: dict,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k].value = dyn_values[k]
for k, v in grad_values.items():
self._grad_vars[k].value = v
# Users should return the scalar value like this::
# >>> return scalar_loss
output = self.target(*args, **kwargs)
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
return output0, (output, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

def __repr__(self):
name = self.__class__.__name__
f = tools.repr_object(self.target)
f = tools.repr_context(f, " " * (len(name) + 6))
format_ref = (f'{name}({self.name}, target={f}, \n' +
f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
return format_ref

def _return(self, rets):
grads, (outputs, new_grad_vs, new_dyn_vs) = rets
for k, v in new_grad_vs.items():
self._grad_vars[k].value = v
for k in new_dyn_vs.keys():
self._dyn_vars[k].value = new_dyn_vs[k]

# check returned grads
if len(self._grad_vars) > 0:
if self._nonvar_argnums is None:
pass
else:
arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
grads = (grads[0], arg_grads)

# check returned value
if self._return_value:
# check aux
if self._has_aux:
return grads, outputs[0], outputs[1]
else:
return grads, outputs
else:
# check aux
if self._has_aux:
return grads, outputs[1]
else:
return grads

def __call__(self, *args, **kwargs):
if jax.config.jax_disable_jit: # disable JIT
rets = self._transform(
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)

elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with VariableStack() as stack:
rets = eval_shape(
self._transform,
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars.values()])
self._eval_dyn_vars = True

# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

rets = self._transform(
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)


def hessian(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
return_value: bool = False,
has_aux: Optional[bool] = None,
holomorphic=False,

# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
"""Hessian of ``func`` as a dense array.
Expand All @@ -705,42 +902,24 @@ def hessian(
Specifies which positional argument(s) to differentiate with respect to (default ``0``).
holomorphic : bool
Indicates whether ``fun`` is promised to be holomorphic. Default False.
return_value : bool
Whether return the hessian values.
dyn_vars : optional, ArrayType, sequence of ArrayType, dict
The dynamically changed variables used in ``func``.
.. deprecated:: 2.4.0
No longer need to provide ``dyn_vars``. This function is capable of automatically
collecting the dynamical variables used in the target ``func``.
child_objs: optional, BrainPyObject, sequnce, dict
.. versionadded:: 2.3.1
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
has_aux : bool, optional
Indicates whether ``fun`` returns a pair where the first element is
considered the output of the mathematical function to be differentiated
and the second element is auxiliary data. Default False.
Returns
-------
obj: ObjectTransform
The transformed object.
"""
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')

return jacfwd(jacrev(func,
dyn_vars=dyn_vars,
child_objs=child_objs,
grad_vars=grad_vars,
argnums=argnums,
holomorphic=holomorphic),
dyn_vars=dyn_vars,
child_objs=child_objs,
grad_vars=grad_vars,
argnums=argnums,
holomorphic=holomorphic,
return_value=return_value)
return GradientTransformPreserveTree(target=func,
transform=jax.hessian,
grad_vars=grad_vars,
argnums=argnums,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic),
return_value=False)


def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
Expand Down
50 changes: 50 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
Expand Up @@ -1171,3 +1171,53 @@ def f(a, b):
self.assertTrue(file.read().strip() == expect_res.strip())



class TestHessian(unittest.TestCase):
def test_hessian5(self):
bm.set_mode(bm.training_mode)

class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.dnn.Dense(num_hidden, 1)

def update(self, x):
return self.out(self.rnn(x))

# define the loss function
def lossfunc(inputs, targets):
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

model = RNN(1, 2)
data_x = bm.random.rand(1, 1000, 1)
data_y = data_x + bm.random.randn(1, 1000, 1)

bp.reset_state(model, 1)
losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
hess_matrix = losshess(data_x, data_y)

weights = model.train_vars().unique()

# define the loss function
def loss_func_for_jax(weight_vals, inputs, targets):
for k, v in weight_vals.items():
weights[k].value = v
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

bp.reset_state(model, 1)
jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)

for k, v in hess_matrix.items():
for kk, vv in v.items():
self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))

bm.clear_buffer_memory()


14 changes: 7 additions & 7 deletions brainpy/_src/math/object_transform/tests/test_base.py
Expand Up @@ -237,12 +237,12 @@ def test1(self):
hh = bp.dyn.HH(1)
hh.reset()

tree = jax.tree_structure(hh)
leaves = jax.tree_leaves(hh)
tree = jax.tree.structure(hh)
leaves = jax.tree.leaves(hh)

print(tree)
print(leaves)
print(jax.tree_unflatten(tree, leaves))
print(jax.tree.unflatten(tree, leaves))
print()


Expand Down Expand Up @@ -281,13 +281,13 @@ def not_close(x, y):
def all_close(x, y):
assert bm.allclose(x, y)

jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

obj.load_state_dict(random_state)
jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)



Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Expand Up @@ -8,6 +8,8 @@ tqdm
pathos
taichi
numba
braincore
braintools


# test requirements
Expand Down

0 comments on commit 4bd1898

Please sign in to comment.