[tune] Update pytorch-lightning integration API (ray-project#38883)
Updates the PTL callbacks to use the new `train.report` API.
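
For context, a minimal before/after sketch of the reporting pattern these callbacks move to. This is illustrative only: the metric names, the fake checkpoint payload, and the public `ray.train.Checkpoint` alias are assumptions, not code from this diff (which imports `ray.train._checkpoint.Checkpoint` internally).

```python
import os
import tempfile

from ray import train
from ray.train import Checkpoint  # public alias; the diff itself uses ray.train._checkpoint


def train_fn(config):
    """Illustrative trial body; `train.report` only works inside a Tune/Train session."""
    # Old pattern: tune.report(val_loss=..., val_acc=...) for metrics, plus a
    # separate tune.checkpoint_dir(step=...) context manager for checkpoints.
    # New pattern: a single train.report() call pairs the metrics dict with an
    # optional Checkpoint, so both are registered together.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Stand-in for trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint")).
        with open(os.path.join(tmpdir, "checkpoint"), "w") as f:
            f.write("fake-weights")
        train.report(
            {"val_loss": 0.1, "val_acc": 0.9},
            checkpoint=Checkpoint.from_directory(tmpdir),
        )
```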

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
3 people authored and arvind-chandra committed Aug 31, 2023
1 parent 667e498 commit c7fcaee
Showing 4 changed files with 84 additions and 115 deletions.
4 changes: 3 additions & 1 deletion python/ray/tune/integration/lightgbm.py
@@ -15,6 +15,7 @@

from lightgbm.callback import CallbackEnv
from lightgbm.basic import Booster
from ray.util.annotations import Deprecated


class TuneCallback:
@@ -195,6 +196,7 @@ def __init__(self, *args, **kwargs):
)


@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
@@ -203,7 +205,7 @@ def __init__(
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if log_once("tune_report_deprecated"):
if log_once("tune_lightgbm_report_deprecated"):
warnings.warn(
"`ray.tune.integration.lightgbm.TuneReportCallback` is deprecated. "
"Use `ray.tune.integration.lightgbm.TuneReportCheckpointCallback` "
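
The user-facing migration for this integration (and for the xgboost one further below) is a one-line class swap; a hedged sketch, assuming a LightGBM-style metric name:

```python
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback

# Before (now warns and delegates, per this diff):
#   from ray.tune.integration.lightgbm import TuneReportCallback
#   callbacks = [TuneReportCallback({"loss": "binary_logloss"})]

# After: the consolidated report-and-checkpoint callback.
callbacks = [TuneReportCheckpointCallback({"loss": "binary_logloss"})]

# Passed to LightGBM as usual inside the trainable, e.g.:
#   lgb.train(params, train_set, valid_sets=[valid_set], callbacks=callbacks)
```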
174 changes: 69 additions & 105 deletions python/ray/tune/integration/pytorch_lightning.py
@@ -1,12 +1,19 @@
import inspect
import logging
import os
import tempfile
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union

from pytorch_lightning import Callback, Trainer, LightningModule
from ray import tune
from ray.util import PublicAPI
from ray import train
from ray.util import log_once
from ray.util.annotations import PublicAPI, Deprecated
from ray.air.checkpoint import Checkpoint as LegacyCheckpoint
from ray.train._checkpoint import Checkpoint
from ray.train._internal.storage import _use_storage_context

import os

logger = logging.getLogger(__name__)

@@ -72,55 +79,56 @@ def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):


@PublicAPI
class TuneReportCallback(TuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback
Reports metrics to Ray Tune.
.. note::
In Ray 2.4, we introduced
:class:`LightningTrainer <ray.train.lightning.LightningTrainer>`,
which provides native integration with PyTorch Lightning. Here is
:ref:`a simple example <lightning_mnist_example>` of how to use
``LightningTrainer``.
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback
Saves checkpoints after each validation step. Also reports metrics to Tune,
which is needed for checkpoint registration.
Args:
metrics: Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will be reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
on: When to trigger checkpoint creations. Must be one of
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
save_checkpoints: If True (default), checkpoints will be saved and
reported to Ray. If False, only metrics will be reported.
on: When to trigger checkpoint creations and metric reports. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".
Example:
.. code-block:: python
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.integration.pytorch_lightning import (
TuneReportCheckpointCallback)
# Report loss and accuracy to Tune after each validation epoch:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
["val_loss", "val_acc"], on="validation_end")])
# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
filename="trainer.ckpt", on="validation_end")])
# Same as above, but report as `loss` and `mean_accuracy`:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
{"loss": "val_loss", "mean_accuracy": "val_acc"},
on="validation_end")])
"""

def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = "checkpoint",
save_checkpoints: bool = True,
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCallback, self).__init__(on=on)
super(TuneReportCheckpointCallback, self).__init__(on=on)
if isinstance(metrics, str):
metrics = [metrics]
self._save_checkpoints = save_checkpoints
self._filename = filename
self._metrics = metrics

def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):
@@ -146,102 +154,58 @@ def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):

return report_dict

def _handle(self, trainer: Trainer, pl_module: LightningModule):
report_dict = self._get_report_dict(trainer, pl_module)
if report_dict is not None:
tune.report(**report_dict)


class _TuneCheckpointCallback(TuneCallback):
"""PyTorch Lightning checkpoint callback
Saves checkpoints after each validation step.
.. note::
In Ray 2.4, we introduced
:class:`LightningTrainer <ray.train.lightning.LightningTrainer>`,
which provides native integration with PyTorch Lightning. Here is
:ref:`a simple example <lightning_mnist_example>` of how to use
``LightningTrainer``.
Checkpoints are currently not registered if no ``tune.report()`` call
is made afterwards. Consider using ``TuneReportCheckpointCallback``
instead.
Args:
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".
@contextmanager
def _get_checkpoint(
self, trainer: Trainer
) -> Optional[Union[Checkpoint, LegacyCheckpoint]]:
if not self._save_checkpoints:
yield None
return

with tempfile.TemporaryDirectory() as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))

"""
if _use_storage_context():
checkpoint = Checkpoint.from_directory(checkpoint_dir)
else:
checkpoint = LegacyCheckpoint.from_directory(checkpoint_dir)

def __init__(
self, filename: str = "checkpoint", on: Union[str, List[str]] = "validation_end"
):
super(_TuneCheckpointCallback, self).__init__(on)
self._filename = filename
yield checkpoint

def _handle(self, trainer: Trainer, pl_module: LightningModule):
if trainer.sanity_checking:
return
step = f"epoch={trainer.current_epoch}-step={trainer.global_step}"
with tune.checkpoint_dir(step=step) as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))


@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback
Saves checkpoints after each validation step. Also reports metrics to Tune,
which is needed for checkpoint registration.
Args:
metrics: Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will be reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".

report_dict = self._get_report_dict(trainer, pl_module)
if not report_dict:
return

Example:
.. code-block:: python
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import (
TuneReportCheckpointCallback)
# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
filename="trainer.ckpt", on="validation_end")])
with self._get_checkpoint(trainer) as checkpoint:
train.report(report_dict, checkpoint=checkpoint)


"""
class _TuneCheckpointCallback(TuneCallback):
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
"is deprecated."
)

_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callbacks_cls = TuneReportCallback

@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = "checkpoint",
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCheckpointCallback, self).__init__(on)
self._checkpoint = self._checkpoint_callback_cls(filename, on)
self._report = self._report_callbacks_cls(metrics, on)

def _handle(self, trainer: Trainer, pl_module: LightningModule):
self._checkpoint._handle(trainer, pl_module)
self._report._handle(trainer, pl_module)
if log_once("tune_ptl_report_deprecated"):
warnings.warn(
"`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
"is deprecated. Use "
"`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
" instead."
)
super(TuneReportCallback, self).__init__(
metrics=metrics, save_checkpoints=False, on=on
)
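
With this change, the pytorch_lightning integration keeps a single public callback; metrics-only reporting becomes the `save_checkpoints=False` flag rather than a separate class. A small usage sketch based on the docstring above (model and metric names assumed):

```python
import pytorch_lightning as pl

from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

# Report metrics and save a checkpoint after every validation epoch.
report_and_checkpoint = TuneReportCheckpointCallback(
    metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
    filename="trainer.ckpt",
    on="validation_end",
)

# Metrics-only reporting: what the deprecated TuneReportCallback now does
# under the hood (it calls super().__init__ with save_checkpoints=False).
report_only = TuneReportCheckpointCallback(
    metrics=["val_loss", "val_acc"],
    save_checkpoints=False,
    on="validation_end",
)

trainer = pl.Trainer(max_epochs=1, callbacks=[report_and_checkpoint])
```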
4 changes: 3 additions & 1 deletion python/ray/tune/integration/xgboost.py
@@ -13,6 +13,7 @@
from ray.train._internal.storage import _use_storage_context
from ray.tune.utils import flatten_dict
from ray.util import log_once
from ray.util.annotations import Deprecated
from xgboost.core import Booster

try:
@@ -188,6 +189,7 @@ def __init__(self, *args, **kwargs):
)


@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
@@ -196,7 +198,7 @@ def __init__(
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if log_once("tune_report_deprecated"):
if log_once("tune_xgboost_report_deprecated"):
warnings.warn(
"`ray.tune.integration.xgboost.TuneReportCallback` is deprecated. "
"Use `ray.tune.integration.xgboost.TuneReportCheckpointCallback` "
17 changes: 9 additions & 8 deletions python/ray/tune/tests/test_integration_pytorch_lightning.py
@@ -13,7 +13,6 @@
from ray.tune.integration.pytorch_lightning import (
TuneReportCallback,
TuneReportCheckpointCallback,
_TuneCheckpointCallback,
)


@@ -89,7 +88,7 @@ def train(config):
max_epochs=1,
callbacks=[
TuneReportCallback(
{"tune_loss": "avg_val_loss"}, on="validation_end"
metrics={"tune_loss": "avg_val_loss"}, on="validation_end"
)
],
)
@@ -106,10 +105,10 @@ def testCheckpointCallback(self):
def train(config):
module = _MockModule(10.0, 20.0)
trainer = pl.Trainer(
max_epochs=1,
max_epochs=10,
callbacks=[
_TuneCheckpointCallback(
"trainer.ckpt", on=["batch_end", "train_end"]
TuneReportCheckpointCallback(
filename="trainer.ckpt", on=["train_epoch_end"]
)
],
)
@@ -128,8 +127,8 @@ def train(config):
for dir in os.listdir(analysis.trials[0].local_path)
if dir.startswith("checkpoint")
]
# 10 checkpoints after each batch, 1 checkpoint at end
self.assertEqual(len(checkpoints), 11)
# 1 checkpoint per epoch
self.assertEqual(len(checkpoints), 10)

def testReportCheckpointCallback(self):
tmpdir = tempfile.mkdtemp()
@@ -141,7 +140,9 @@ def train(config):
max_epochs=1,
callbacks=[
TuneReportCheckpointCallback(
["avg_val_loss"], "trainer.ckpt", on="validation_end"
metrics=["avg_val_loss"],
filename="trainer.ckpt",
on="validation_end",
)
],
)
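The updated test expects one checkpoint per epoch rather than one per batch plus one at train end; a hedged sketch of the counting logic it relies on (paths illustrative):

```python
import os


def count_checkpoints(trial_local_path: str) -> int:
    # Each train.report(..., checkpoint=...) call produces one checkpoint_*
    # directory under the trial directory, so on=["train_epoch_end"] with
    # max_epochs=10 yields 10 of them.
    return len(
        [d for d in os.listdir(trial_local_path) if d.startswith("checkpoint")]
    )
```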
