Skip to content

Commit

Permalink
Update autograd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 10, 2024
1 parent 25e0158 commit 4e3151e
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,12 @@ def hessian(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: Optional[bool] = None,
return_value: bool = False,
holomorphic=False,

# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
"""Hessian of ``func`` as a dense array.
Expand All @@ -912,14 +916,29 @@ def hessian(
obj: ObjectTransform
The transformed object.
"""
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')

return GradientTransformPreserveTree(target=func,
transform=jax.hessian,
grad_vars=grad_vars,
argnums=argnums,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic),
return_value=False)
return jacfwd(jacrev(func,
dyn_vars=dyn_vars,
child_objs=child_objs,
grad_vars=grad_vars,
argnums=argnums,
holomorphic=holomorphic),
dyn_vars=dyn_vars,
child_objs=child_objs,
grad_vars=grad_vars,
argnums=argnums,
holomorphic=holomorphic,
return_value=return_value)

# return GradientTransformPreserveTree(target=func,
# transform=jax.hessian,
# grad_vars=grad_vars,
# argnums=argnums,
# has_aux=False if has_aux is None else has_aux,
# transform_setting=dict(holomorphic=holomorphic),
# return_value=False)


def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
Expand Down

0 comments on commit 4e3151e

Please sign in to comment.