bugfix: LR finder. Reset optimizer state
torzdf committed Sep 8, 2023
1 parent 8388241 commit 48aca14
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions lib/training/lr_finder.py
@@ -137,6 +137,34 @@ def _train(self) -> None:
             self._on_batch_end(idx, loss[0])
             self._update_description(pbar)
 
+    def _reset_model(self, original_lr: float, new_lr: float) -> None:
+        """ Reset the model's weights to initial values, reset the model's optimizer and set the
+        learning rate
+
+        Parameters
+        ----------
+        original_lr: float
+            The model's original learning rate
+        new_lr: float
+            The discovered optimal learning rate
+        """
+        self._model.state.update_session_config("learning_rate", new_lr)
+        self._model.state.save()
+
+        logger.debug("Loading initial weights")
+        self._model.model.load_weights(self._model.io.filename)
+
+        if self._config["lr_finder_mode"] == "graph_and_exit":
+            return
+
+        opt_conf = self._model.model.optimizer.get_config()
+        logger.debug("Recompiling model to reset optimizer state. Optimizer config: %s", opt_conf)
+        new_opt = self._model.model.optimizer.__class__(**opt_conf)
+        self._model.model.compile(optimizer=new_opt, loss=self._model.model.loss)
+
+        logger.info("Updating Learning Rate from %s to %s", f"{original_lr:.1e}", f"{new_lr:.1e}")
+        K.set_value(self._model.model.optimizer.lr, new_lr)
+
     def find(self) -> bool:
         """ Find the optimal learning rate
@@ -162,17 +190,8 @@ def find(self) -> bool:
             shutil.rmtree(self._model.io.model_dir)
             return False
 
-        if self._save_graph:
-            self._plot_loss()
-
-        if not self._config["lr_finder_mode"] == "graph_and_exit":
-            logger.info("Updating Learning Rate from %s to %s",
-                        f"{original_lr:.1e}", f"{new_lr:.1e}")
-            self._model.model.load_weights(self._model.io.filename)
-            K.set_value(self._model.model.optimizer.lr, new_lr)
-
-        self._model.state.update_session_config("learning_rate", new_lr)
-        self._model.state.save()
+        self._plot_loss()
+        self._reset_model(original_lr, new_lr)
         return True
 
     def _plot_loss(self, skip_begin: int = 10, skip_end: int = 1) -> None:
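For context, find() drives a learning-rate range test: train briefly while sweeping the learning rate upward exponentially, record the loss at each step, and pick a rate from the shape of the curve. A generic sketch of that idea, with a made-up train_step callback and selection heuristic; this is not faceswap's implementation:

    import numpy as np

    def lr_range_test(train_step, start_lr=1e-8, end_lr=1e-1, steps=100):
        """ Sweep the learning rate exponentially, recording loss per step. """
        factor = (end_lr / start_lr) ** (1.0 / steps)
        learning_rate, lrs, losses = start_lr, [], []
        for _ in range(steps):
            losses.append(train_step(learning_rate))  # one batch at this rate
            lrs.append(learning_rate)
            learning_rate *= factor
        # One common heuristic: the rate where the loss fell fastest
        return lrs[int(np.argmin(np.gradient(np.array(losses))))]

    # Toy usage: a fake train_step whose loss bottoms out around 1e-4
    best_lr = lr_range_test(lambda lr: (np.log10(lr) + 4.0) ** 2)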
@@ -185,6 +204,9 @@ def _plot_loss(self, skip_begin: int = 10, skip_end: int = 1) -> None:
         skip_end: int, optional
             Number of iterations to skip at the end. Default: `1`
         """
+        if not self._save_graph:
+            return
+
         matplotlib.use("Agg")
         lrs = self._metrics["learning_rates"][skip_begin:-skip_end]
         losses = self._metrics["losses"][skip_begin:-skip_end]
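For reference, a self-contained sketch of the loss-versus-learning-rate plot that _plot_loss builds, with synthetic data standing in for self._metrics (axis labels and the output filename are illustrative assumptions):

    import matplotlib
    matplotlib.use("Agg")   # headless backend: render straight to a file
    import matplotlib.pyplot as plt
    import numpy as np

    skip_begin, skip_end = 10, 1
    lrs = np.geomspace(1e-8, 1e-1, 100)              # stand-in for the metrics
    losses = (np.log10(lrs) + 4.0) ** 2 / 16 + 0.1   # stand-in loss curve

    # Drop the noisy warm-up and the final step, as the real method does
    plt.plot(lrs[skip_begin:-skip_end], losses[skip_begin:-skip_end])
    plt.xscale("log")
    plt.xlabel("learning rate (log scale)")
    plt.ylabel("loss")
    plt.savefig("lr_finder.png")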
