Callback Error for example pytorch_lightning_simple.py #166
Hi,

For distributed training (i.e., multi-process), no. For single-process training, the old version of the callback, https://github.com/optuna/optuna/blob/release-v2.10.1/optuna/integration/pytorch_lightning.py, might work fine, but I haven't tested it.
The shared code had a minor issue due to PyTorch Lightning's default sanity check. The following callback should work:

import warnings

import optuna
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the dictionaries returned by e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end``, so the names
            depend on how those dictionaries are formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()
        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # When the trainer calls `on_validation_end` during the sanity check,
        # skip `trial.report` to avoid reporting multiple times at epoch 0. See
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
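For reference, here is a minimal, self-contained sketch of how the callback above might be wired into an Optuna objective. The LitClassifier model, the make_loader helper, and the hyperparameter range are illustrative placeholders, not part of the original example.

import optuna
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


# Toy model for illustration; any LightningModule that logs the monitored
# metric (here "val_acc") in its validation loop works the same way.
class LitClassifier(LightningModule):
    def __init__(self, lr: float) -> None:
        super().__init__()
        self.lr = lr
        self.layer = nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self.layer(x).argmax(dim=1) == y).float().mean()
        # `self.log` stores the value in `trainer.callback_metrics`,
        # which is where the pruning callback reads it.
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def make_loader() -> DataLoader:
    # Random data stands in for a real dataset.
    x = torch.randn(64, 4)
    y = torch.randint(0, 2, (64,))
    return DataLoader(TensorDataset(x, y), batch_size=16)


def objective(trial: optuna.trial.Trial) -> float:
    model = LitClassifier(lr=trial.suggest_float("lr", 1e-4, 1e-1, log=True))
    trainer = Trainer(
        max_epochs=5,
        enable_checkpointing=False,
        logger=False,
        # The pruning callback defined above; it raises optuna.TrialPruned
        # from `on_validation_end` when the trial should stop early.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
    trainer.fit(model, train_dataloaders=make_loader(), val_dataloaders=make_loader())
    return trainer.callback_metrics["val_acc"].item()


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10)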
This worked for me. Great, thank you!
When trying to run the example optuna-examples/pytorch_lightning_simple.py I get a runtime error.

Error messages, stack traces, or logs

RuntimeError: The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8.

Additional context (optional)

If I comment out the callback part, the code runs without problems, but that removes the pruning, which is quite important for the example. Sadly, I could not find a working example, so I cannot really suggest a fix for Lightning.