Skip to content

Commit

Permalink
update stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
jobrachem committed Apr 4, 2024
1 parent db5439f commit 070964e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
29 changes: 23 additions & 6 deletions liesel_ptm/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class Stopper:
max_iter: int
patience: int
atol: float = 1e-3
rtol: float = 1e-6
rtol: float = 1e-12

def stop_early(self, i: int | Array, loss_history: Array):
"""
Expand All @@ -122,13 +122,30 @@ def stop_early(self, i: int | Array, loss_history: Array):
best_loss_in_recent = jnp.min(recent_history)
current_loss = loss_history[i]

change = best_loss_in_recent - current_loss

change = current_loss - best_loss_in_recent
"""
If current_loss is better than best_loss_in_recent, this is negative.
If current_loss is worse, this is positive.
"""
rel_change = jnp.abs(jnp.abs(change) / best_loss_in_recent)

no_improvement = change < self.atol
no_rel_change = rel_change < self.rtol
return no_improvement & no_rel_change & (i > p)
no_improvement = change > self.atol
"""
If the current loss has not improved upon the best loss in the patience
period, we always want to stop. However, we actually allow for slightly
worse losses, defined by the absolute tolerance here.
"""

no_rel_change = ~no_improvement & (rel_change < self.rtol)
"""
Let's say the current value *does* improve upon the best value within patience,
such that no_improvement=False.
In this case, if the improvement is very small compared to the best observed
loss in the patience period, we may still want to stop.
"""

return (no_improvement | no_rel_change) & (i > p)

def stop_now(self, i: int | Array, loss_history: Array):
"""Whether optimization should stop now."""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def test_stopper_does_not_stop(self):

key = jax.random.PRNGKey(42)
loss_history = jax.random.uniform(key, shape=(15,))
loss_history = loss_history.at[6].set(-0.1)

stop = stopper.stop_early(i=6, loss_history=loss_history)

Expand All @@ -245,7 +246,7 @@ def test_stopper_jitted(self):
key = jax.random.PRNGKey(42)
loss_history = jax.random.uniform(key, shape=(15,))
stop = stop_jit(i=6, loss_history=loss_history)
assert not stop
assert stop

def test_stop_at_jitted(self):
stopper = Stopper(patience=5, max_iter=100, atol=0.1, rtol=0.1)
Expand Down

0 comments on commit 070964e

Please sign in to comment.