Skip to content

Commit

Permalink
Merge pull request #167 from mblondel:lbfgs_value
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 425915191
  • Loading branch information
JAXopt authors committed Feb 2, 2022
2 parents 2746140 + c742e25 commit eb6e75d
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def update_history(history_pytree, new_pytree, last):
class LbfgsState(NamedTuple):
"""Named tuple containing state information."""
iter_num: int
value: float
stepsize: float
error: float
s_history: Any
Expand Down Expand Up @@ -216,6 +217,7 @@ def init_state(self,
state
"""
return LbfgsState(iter_num=jnp.asarray(0),
value=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(1.0),
error=jnp.asarray(jnp.inf),
s_history=init_history(init_params, self.history_size),
Expand Down Expand Up @@ -263,6 +265,7 @@ def update(self,
params=params, value=value, grad=grad,
descent_direction=descent_direction,
*args, **kwargs)
new_value = ls_state.value
new_params = ls_state.params
new_grad = ls_state.grad

Expand All @@ -276,11 +279,15 @@ def update(self,
rho_history = update_history(state.rho_history, rho, last)

new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(grad),
error=tree_l2_norm(new_grad),
s_history=s_history,
y_history=y_history,
rho_history=rho_history,
# FIXME: we should return new_aux here but
# BacktrackingLineSearch currently doesn't support
# an aux.
aux=aux)

return base.OptStep(params=new_params, state=new_state)
Expand Down

0 comments on commit eb6e75d

Please sign in to comment.