Callback Error for example pytorch_lightning_simple.py #166

Closed
leoguen opened this issue Jan 24, 2023 · 5 comments
Labels
bug Issue/PR about behavior that is broken. Not for typos/CI but for example itself.



leoguen commented Jan 24, 2023

When trying to run the example optuna-examples/pytorch_lightning_simple.py, I get the following runtime error: RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.

Environment

  • Optuna version: 3.1.0
  • Python version: 3.8.10
  • OS: Ubuntu 22.04
  • (Optional) Other libraries and their versions:

Error messages, stack traces, or logs

RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.

Additional context (optional)

If I comment out the callback part, the code runs without problems, but this eliminates the pruning functionality, which is quite important for the example. Sadly, I could not find a working example, so I cannot really suggest a fix for Lightning.

    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=PERCENT_VALID_EXAMPLES,
        enable_checkpointing=False,
        max_epochs=EPOCHS,
        gpus=1 if torch.cuda.is_available() else None,
        #callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
leoguen added the bug label on Jan 24, 2023

nzw0301 commented Jan 25, 2023

#148 (comment)

nzw0301 closed this as completed on Jan 25, 2023

leoguen commented Jan 25, 2023

Hi,
I understand that the example's requirements pin Lightning to between 1.5 and 1.6. As someone who is trying to build on this example, is there another common way to implement pruning with a recent Lightning version?
Cheers,
Leo


nzw0301 commented Jan 25, 2023

For distributed (multi-process) training, no. For single-process training, the old version of the callback (https://github.com/optuna/optuna/blob/release-v2.10.1/optuna/integration/pytorch_lightning.py) might work fine, but I haven't tested it.


nzw0301 commented Jan 25, 2023

The shared code had a minor issue due to PyTorch Lightning's default sanity check. The following callback should work.

import warnings

import optuna
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.
    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.
    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()

        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # When the trainer calls `on_validation_end` for sanity check,
        # do not call `trial.report` to avoid calling `trial.report` multiple times
        # at epoch 0. The related page is
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
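
For reference, re-enabling the callback in the example's Trainer would then look roughly like this. This is a minimal sketch, not the example verbatim: LightningNet, EPOCHS, and PERCENT_VALID_EXAMPLES stand in for the example's own definitions.

import optuna
import pytorch_lightning as pl
import torch


def objective(trial: optuna.trial.Trial) -> float:
    # Placeholder for the example's LightningModule.
    model = LightningNet()
    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=PERCENT_VALID_EXAMPLES,
        enable_checkpointing=False,
        max_epochs=EPOCHS,
        gpus=1 if torch.cuda.is_available() else None,
        # Use the callback defined above instead of the one from
        # optuna.integration, which triggers the on_init_start error.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
    trainer.fit(model)
    return trainer.callback_metrics["val_acc"].item()


study = optuna.create_study(
    direction="maximize", pruner=optuna.pruners.MedianPruner()
)
study.optimize(objective, n_trials=100)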


leoguen commented Feb 3, 2023

This worked for me. Great, thank you!
