[AIR] Remove extra sessions that are not needed any more (ray-project#37023)

Having multiple sessions floating around is confusing, and going forward we are going to replace the session concept with a unified context object shared between Train and Tune (see ray-project#36706).

The changes in detail:

- Remove the `Session` interface class. We are not planning to expose it to users, and it introduces an additional level of abstraction that is neither needed nor aligned with the longer-term plan of a unified context object between Train and Tune.

- Remove `_TrainSessionImpl` and `_TuneSessionImpl` and instead push their functionality down into `_StatusReporter` and `_TrainSession`. We might want to rename `_StatusReporter` to `_TuneSession` for consistency. With the shared base class gone, the `isinstance` checks in `ray/air/session.py` become duck-typed `hasattr` checks, as sketched below.
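For illustration, here is a minimal sketch of that duck-typing pattern. The `Mock*` classes are stand-ins invented for this example, not the actual Ray types:

```python
from typing import Dict


class MockTuneSession:
    """Stand-in for the Tune-side session (_StatusReporter): no rank attributes."""

    def report(self, metrics: Dict) -> None:
        print(f"reported: {metrics}")


class MockTrainSession(MockTuneSession):
    """Stand-in for _TrainSession: adds the distributed-worker attributes."""

    world_size: int = 4
    world_rank: int = 0


def get_world_size(session) -> int:
    # Instead of isinstance(session, _TrainSessionImpl), check for an
    # attribute that only the train session exposes.
    if not hasattr(session, "world_size"):
        raise RuntimeError("`get_world_size` can only be called for TrainSession!")
    return session.world_size


assert get_world_size(MockTrainSession()) == 4
```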

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
pcmoritz authored and arvind-chandra committed Aug 31, 2023
1 parent 6a2e964 commit e5be08d
Showing 7 changed files with 54 additions and 242 deletions.
87 changes: 3 additions & 84 deletions python/ray/air/_internal/session.py
@@ -1,93 +1,12 @@
-import abc
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
-
-from ray.air.checkpoint import Checkpoint
-
-if TYPE_CHECKING:
-    from ray.tune.execution.placement_groups import PlacementGroupFactory
 
 logger = logging.getLogger(__name__)
 
 
-class Session(abc.ABC):
-    """The canonical session interface that both Tune and Train session implements.
-    User can interact with this interface to get session information,
-    as well as reporting metrics and saving checkpoint.
-    """
-
-    @abc.abstractmethod
-    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
-        """Report metrics and optionally save checkpoint.
-        Each invocation of this method will automatically increment the underlying
-        iteration number. The physical meaning of this "iteration" is defined by
-        user (or more specifically the way they call ``report``).
-        It does not necessarily map to one epoch.
-        This API is supposed to replace the legacy ``tune.report``,
-        ``with tune.checkpoint_dir``, ``train.report`` and ``train.save_checkpoint``.
-        Please avoid mixing them together.
-        There is no requirement on what is the underlying representation of the
-        checkpoint.
-        All forms are accepted and (will be) handled by AIR in an efficient way.
-        Specifically, if you are passing in a directory checkpoint, AIR will move
-        the content of the directory to AIR managed directory. By the return of this
-        method, one may safely write new content to the original directory without
-        interfering with AIR checkpointing flow.
-        Args:
-            metrics: The metrics you want to report.
-            checkpoint: The optional checkpoint you want to report.
-        """
-
-        raise NotImplementedError
-
-    @property
-    @abc.abstractmethod
-    def loaded_checkpoint(self) -> Optional[Checkpoint]:
-        """Access the session's loaded checkpoint to resume from if applicable.
-        Returns:
-            Checkpoint object if the session is currently being resumed.
-            Otherwise, return None.
-        """
-
-        raise NotImplementedError
-
-    @property
-    def experiment_name(self) -> str:
-        """Experiment name for the corresponding trial."""
-        raise NotImplementedError
-
-    @property
-    def trial_name(self) -> str:
-        """Trial name for the corresponding trial."""
-        raise NotImplementedError
-
-    @property
-    def trial_id(self) -> str:
-        """Trial id for the corresponding trial."""
-        raise NotImplementedError
-
-    @property
-    def trial_resources(self) -> "PlacementGroupFactory":
-        """Trial resources for the corresponding trial."""
-        raise NotImplementedError
-
-    @property
-    def trial_dir(self) -> str:
-        """Trial-level log directory for the corresponding trial."""
-        raise NotImplementedError
-
-
-def _get_session(warn: bool = True) -> Optional[Session]:
-    from ray.train._internal.session import _session_v2 as train_session
-    from ray.tune.trainable.session import _session_v2 as tune_session
+def _get_session(warn: bool = True):
+    from ray.train._internal.session import _session as train_session
+    from ray.tune.trainable.session import _session as tune_session
 
     if train_session and tune_session:
         if warn:
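The hunk above is truncated after `if warn:`, so the resolver's behavior when both sessions are set lives in the elided lines. As a rough sketch of the shape of such a resolver (the preference for the Train session here is an assumption, not taken from the diff):

```python
from typing import Optional


def resolve_session(train_session: Optional[object],
                    tune_session: Optional[object],
                    warn: bool = True) -> Optional[object]:
    # Both set means a Train worker is running inside a Tune trial; the
    # real implementation warns in this case (its body is elided above).
    if train_session and tune_session:
        if warn:
            print("A Train session and a Tune session are both active.")
        return train_session  # assumed precedence, not confirmed by the diff
    return train_session or tune_session
```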
13 changes: 6 additions & 7 deletions python/ray/air/session.py
@@ -5,7 +5,6 @@
 from ray.air._internal.session import _get_session
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY
-from ray.train.session import _TrainSessionImpl
 from ray.util import log_once
 from ray.util.annotations import PublicAPI
@@ -216,7 +215,7 @@ def train_loop_per_worker(config):
         trainer.fit()
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "world_size"):
         raise RuntimeError(
             "`get_world_size` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
@@ -250,7 +249,7 @@ def train_loop_per_worker():
         trainer.fit()
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "world_rank"):
         raise RuntimeError(
             "`get_world_rank` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
@@ -283,7 +282,7 @@ def train_loop_per_worker():
         trainer.fit()
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "local_rank"):
         raise RuntimeError(
             "`get_local_rank` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
@@ -314,7 +313,7 @@ def get_local_world_size() -> int:
         >>> trainer.fit()  # doctest: +SKIP
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "local_world_size"):
         raise RuntimeError(
             "`get_local_world_size` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
@@ -345,7 +344,7 @@ def get_node_rank() -> int:
         >>> trainer.fit()  # doctest: +SKIP
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "node_rank"):
         raise RuntimeError(
             "`get_node_rank` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
@@ -397,7 +396,7 @@ def train_loop_per_worker():
         If no dataset is passed into Trainer, then return None.
     """
     session = _get_session()
-    if not isinstance(session, _TrainSessionImpl):
+    if not hasattr(session, "get_dataset_shard"):
         raise RuntimeError(
             "`get_dataset_shard` can only be called for TrainSession! "
             "Make sure you only use that in `train_loop_per_worker` function"
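Per the docstrings elided from the hunks above, these guarded helpers are meant to be called from inside a training loop. Roughly, usage looks like this (a sketch assuming a standard Ray Train worker context; not part of this diff):

```python
from ray.air import session


def train_loop_per_worker(config):
    # These calls only succeed when a Train session is active, i.e. when
    # this function runs inside a Trainer worker.
    rank = session.get_world_rank()
    size = session.get_world_size()
    for epoch in range(config.get("epochs", 1)):
        session.report({"epoch": epoch, "rank": rank, "world_size": size})
```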
45 changes: 37 additions & 8 deletions python/ray/train/_internal/session.py
@@ -10,7 +10,8 @@
 from enum import Enum, auto
 from pathlib import Path
 import shutil
-from typing import Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+import warnings
 
 import ray
 from ray.air._internal.util import StartTraceback, RunnerThread
@@ -35,11 +36,15 @@
 )
 
 from ray.train.error import SessionMisuseError
-from ray.train.session import _TrainSessionImpl
 from ray.util.annotations import DeveloperAPI
 from ray.util.debug import log_once
 
 
+if TYPE_CHECKING:
+    from ray.data import DataIterator
+    from ray.tune.execution.placement_groups import PlacementGroupFactory
+
+
 _INDEX_FILE_EXTENSION = ".files"
 _INDEX_FILE = ".RANK_{0}" + _INDEX_FILE_EXTENSION
@@ -439,22 +444,48 @@ def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
             self.checkpoint(checkpoint)
         self._report_legacy(**metrics)
 
+    @property
+    def trial_resources(self) -> "PlacementGroupFactory":
+        return self.trial_info.resources
+
+    @property
+    def trial_dir(self) -> str:
+        return self.trial_info.logdir
+
+    def get_dataset_shard(
+        self,
+        dataset_name: Optional[str] = None,
+    ) -> Optional["DataIterator"]:
+        shard = self.dataset_shard
+        if shard is None:
+            warnings.warn(
+                "No dataset passed in. Returning None. Make sure to "
+                "pass in a Dataset to Trainer.run to use this "
+                "function."
+            )
+        elif isinstance(shard, dict):
+            if not dataset_name:
+                raise RuntimeError(
+                    "Multiple datasets were passed into ``Trainer``, "
+                    "but no ``dataset_name`` is passed into "
+                    "``get_dataset_shard``. Please specify which "
+                    "dataset shard to retrieve."
+                )
+            return shard.get(dataset_name)
+        return shard
+
 
 _session: Optional[_TrainSession] = None
-# V2 Session API
-_session_v2: Optional[_TrainSessionImpl] = None
 
 
 def init_session(*args, **kwargs) -> None:
     global _session
-    global _session_v2
     if _session:
         raise ValueError(
             "A Train session is already in use. Do not call "
             "`init_session()` manually."
         )
     _session = _TrainSession(*args, **kwargs)
-    _session_v2 = _TrainSessionImpl(session=_session)
 
 
 def get_session() -> Optional[_TrainSession]:
@@ -464,9 +495,7 @@ def get_session() -> Optional[_TrainSession]:
 def shutdown_session():
     """Shuts down the initialized session."""
     global _session
-    global _session_v2
     _session = None
-    _session_v2 = None
 
 
 def _raise_accelerator_session_misuse():
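The `get_dataset_shard` logic pushed down into `_TrainSession` above distinguishes three cases. A self-contained sketch of that dispatch (`resolve_dataset_shard` is a hypothetical free function mirroring the method body):

```python
import warnings
from typing import Dict, Optional, Union


def resolve_dataset_shard(
    shard: Union[None, object, Dict[str, object]],
    dataset_name: Optional[str] = None,
) -> Optional[object]:
    if shard is None:
        # No dataset was passed to the Trainer: warn and return None.
        warnings.warn("No dataset passed in. Returning None.")
    elif isinstance(shard, dict):
        # Multiple named datasets: a dataset_name is required to pick one.
        if not dataset_name:
            raise RuntimeError("Multiple datasets passed; specify dataset_name.")
        return shard.get(dataset_name)
    return shard


assert resolve_dataset_shard({"train": "train_shard"}, "train") == "train_shard"
```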
96 changes: 0 additions & 96 deletions python/ray/train/session.py
@@ -1,96 +0,0 @@
-import warnings
-from typing import TYPE_CHECKING, Dict, Optional
-
-from ray.air._internal.session import Session
-from ray.air.checkpoint import Checkpoint
-
-if TYPE_CHECKING:
-    # avoid circular import
-    from ray.data import DataIterator
-    from ray.train._internal.session import _TrainSession
-    from ray.tune.execution.placement_groups import PlacementGroupFactory
-
-
-class _TrainSessionImpl(Session):
-    """Session client that "per worker train loop" can interact with.
-    Notice that each worker will automatically switch to its working
-    directory on entering the train loop. This is to ensure that
-    each worker can safely write to a local directory without racing
-    and overwriting each other."""
-
-    def __init__(self, session: "_TrainSession"):
-        self._session = session
-
-    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
-        self._session.report(metrics, checkpoint)
-
-    @property
-    def loaded_checkpoint(self) -> Optional[Checkpoint]:
-        ckpt = self._session.loaded_checkpoint
-        if ckpt:
-            # The new API should only interact with Checkpoint object.
-            assert isinstance(ckpt, Checkpoint)
-        return ckpt
-
-    @property
-    def experiment_name(self) -> str:
-        return self._session.trial_info.experiment_name
-
-    @property
-    def trial_name(self) -> str:
-        return self._session.trial_info.name
-
-    @property
-    def trial_id(self) -> str:
-        return self._session.trial_info.id
-
-    @property
-    def trial_resources(self) -> "PlacementGroupFactory":
-        return self._session.trial_info.resources
-
-    @property
-    def trial_dir(self) -> str:
-        return self._session.trial_info.logdir
-
-    @property
-    def world_size(self) -> int:
-        return self._session.world_size
-
-    @property
-    def world_rank(self) -> int:
-        return self._session.world_rank
-
-    @property
-    def local_rank(self) -> int:
-        return self._session.local_rank
-
-    @property
-    def local_world_size(self) -> int:
-        return self._session.local_world_size
-
-    @property
-    def node_rank(self) -> int:
-        return self._session.node_rank
-
-    def get_dataset_shard(
-        self,
-        dataset_name: Optional[str] = None,
-    ) -> Optional["DataIterator"]:
-        shard = self._session.dataset_shard
-        if shard is None:
-            warnings.warn(
-                "No dataset passed in. Returning None. Make sure to "
-                "pass in a Dataset to Trainer.run to use this "
-                "function."
-            )
-        elif isinstance(shard, dict):
-            if not dataset_name:
-                raise RuntimeError(
-                    "Multiple datasets were passed into ``Trainer``, "
-                    "but no ``dataset_name`` is passed into "
-                    "``get_dataset_shard``. Please specify which "
-                    "dataset shard to retrieve."
-                )
-            return shard.get(dataset_name)
-        return shard
2 changes: 0 additions & 2 deletions python/ray/tune/trainable/__init__.py
@@ -1,6 +1,5 @@
 from ray.tune.trainable.trainable import Trainable
 from ray.tune.trainable.util import TrainableUtil, with_parameters
-from ray.tune.trainable.session import Session
 from ray.tune.trainable.function_trainable import (
     FunctionTrainable,
     FuncCheckpointUtil,
@@ -11,7 +10,6 @@
 __all__ = [
     "Trainable",
     "TrainableUtil",
-    "Session",
     "FunctionTrainable",
     "FuncCheckpointUtil",
     "with_parameters",
5 changes: 5 additions & 0 deletions python/ray/tune/trainable/function_trainable.py
@@ -282,6 +282,11 @@ def trial_resources(self):
         """Resources assigned to the trial of this Trainable."""
         return self._trial_resources
 
+    @property
+    def trial_dir(self) -> str:
+        """Trial-level log directory for the corresponding trial."""
+        return self._logdir
+
 
 @DeveloperAPI
 class FunctionTrainable(Trainable):
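The new `trial_dir` property mirrors the one the Train session gained above and backs the public `session.get_trial_dir()` accessor. A usage sketch (assuming a standard Tune function trainable; not part of this diff):

```python
from ray.air import session


def trainable(config):
    # Inside a Tune trial, this returns the trial-level log directory
    # provided by the new trial_dir property.
    trial_dir = session.get_trial_dir()
    session.report({"has_trial_dir": bool(trial_dir)})
```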