[train/tune] Refactor trial metadata organization (ray-project#38165)
The Trial object currently keeps properties with different scopes:

1. Static properties that are set **on init**
2. Static properties that are set on init but can be overwritten **on restore**
3. **Temporary** properties that are not saved, e.g. the trial location
4. **Run metadata** that is updated during training, such as the last result and the available checkpoints

This PR refactors the `Trial` class to explicitly capture 3) and 4) in separate helper classes.

Specifically, it introduces a `_TemporaryTrialState` class that contains temporary properties, and a `_TrainingRunMetadata` class that contains run metadata such as the last result, error files, and available checkpoints.
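
A minimal sketch of how these two helper classes might be shaped, for orientation only (the attribute names come from the commit message and the diff below; defaults and everything else are assumptions, not the exact implementation):

```python
# Rough sketch only -- the real classes live alongside `Trial` in Ray Tune
# and may differ in attributes, defaults, and methods.


class _TemporaryTrialState:
    """Trial state that is *not* persisted in experiment checkpoints."""

    def __init__(self):
        self.location = None        # where the trial is currently running
        self.ray_actor = None       # handle to the trainable Ray actor
        self.saving_to = None       # checkpoint currently being saved, if any
        self.restoring_from = None  # checkpoint currently being restored, if any


class _TrainingRunMetadata:
    """Run metadata that is updated during training and persisted."""

    def __init__(self, checkpoint_manager=None):
        self.last_result = {}                # most recent reported result
        self.last_result_time = None         # timestamp of that result
        self.error_filename = None           # error file for this trial, if any
        self.pickled_error_filename = None   # pickled exception file, if any
        self.checkpoint_manager = checkpoint_manager  # tracks available checkpoints
```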

It also changes the way experiment checkpoints are saved: the trial state (which contains the static properties as well as select runtime metadata such as the trial status) is now saved separately from the run metadata. This allows us to split these two sources of information in the future.
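
Concretely, each trial is now stored as a `(trial_json_state, trial_runtime_metadata)` pair under a `trial_data` key, replacing the previous flat `checkpoints` list. Below is a simplified sketch of the round trip, using hypothetical helper functions in place of `TuneController.save_to_dir()`/`restore_from_dir()`; the key and method names come from the diff below, while the import path and type hints are assumptions:

```python
from typing import Any, Dict, List, Tuple

from ray.tune.experiment import Trial


def build_runner_state(
    trial_data: List[Tuple[str, Any]], runner_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Each trial contributes a (json_state, runtime_metadata) pair."""
    return {
        "trial_data": trial_data,    # replaces the old flat "checkpoints" list
        "runner_data": runner_data,  # experiment-level state
        # ... plus additional metadata such as timestamps
    }


def restore_trials(runner_state: Dict[str, Any]) -> List[Trial]:
    """On restore, the two parts are re-applied to each trial separately."""
    trials = []
    for trial_json_state, trial_runtime_metadata in runner_state["trial_data"]:
        trial = Trial.from_json_state(trial_json_state)
        trial.restore_run_metadata(trial_runtime_metadata)
        trials.append(trial)
    return trials
```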

The changes in this PR mean that experiment checkpoints written before this change can no longer be loaded. However, the same is true for any change to the `Trial` class: backwards compatibility for resuming experiments is only guaranteed on a patch-level basis, so these changes are acceptable.

A few other improvements are included (for example, `Trial.runner` is now `Trial.temporary_state.ray_actor`), and a large part of the diff consists of test updates.
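
As an example of the rename, call sites that attach or check the Ray actor handle change roughly as follows (based on the call sites updated in `tune_controller.py` below; `trial` and `ray_actor` are placeholders):

```python
# Before this PR
trial.set_runner(ray_actor)
if trial.runner:
    ...

# After this PR
trial.set_ray_actor(ray_actor)
if trial.temporary_state.ray_actor:
    ...
```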

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: harborn <gangsheng.wu@intel.com>
krfricke authored and harborn committed Aug 17, 2023
1 parent 74e333d commit dbebcab
Showing 26 changed files with 385 additions and 271 deletions.
4 changes: 4 additions & 0 deletions python/ray/air/_internal/checkpoint_manager.py
@@ -301,6 +301,10 @@ def __init__(

self.set_delete_fn(delete_fn)

@property
def checkpoint_config(self):
return self._checkpoint_strategy

def set_delete_fn(
self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]]
):
4 changes: 4 additions & 0 deletions python/ray/train/_internal/checkpoint_manager.py
@@ -74,6 +74,10 @@ def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
f"{self._checkpoint_config.num_to_keep}"
)

@property
def checkpoint_config(self):
return self._checkpoint_config

def register_checkpoint(self, checkpoint_result: _TrainingResult):
"""Register new checkpoint and add to bookkeeping.
10 changes: 7 additions & 3 deletions python/ray/tune/analysis/experiment_analysis.py
@@ -212,11 +212,11 @@ def _load_checkpoints_from_latest(self, latest_checkpoint: List[str]) -> None:
experiment_state = json.load(f, cls=TuneFunctionDecoder)
self._experiment_states.append(experiment_state)

if "checkpoints" not in experiment_state:
if "trial_data" not in experiment_state:
raise TuneError("Experiment state invalid; no checkpoints found.")

self._checkpoints_and_paths += [
(cp, Path(path).parent) for cp in experiment_state["checkpoints"]
(cp, Path(path).parent) for cp in experiment_state["trial_data"]
]

def _maybe_download_experiment_checkpoint(
@@ -978,9 +978,13 @@ def _get_trial_paths(self) -> List[str]:
"since checkpointing is periodic."
)
self.trials = []
for trial_json_state, path in self._checkpoints_and_paths:
for (
trial_json_state,
trial_run_metadata,
), path in self._checkpoints_and_paths:
try:
trial = Trial.from_json_state(trial_json_state, stub=True)
trial.restore_run_metadata(trial_run_metadata)
# TODO(justinvyu): [handle_moved_storage_path]
if not _use_storage_context():
trial.local_experiment_path = str(path)
54 changes: 29 additions & 25 deletions python/ray/tune/execution/tune_controller.py
@@ -446,7 +446,7 @@ def save_to_dir(self, experiment_dir: Optional[str] = None):
# Get state from trial executor and runner
runner_state = {
# Trials
"checkpoints": list(self._get_trial_checkpoints().values()),
"trial_data": list(self._get_trial_checkpoints().values()),
# Experiment data
"runner_data": self.__getstate__(),
# Metadata
@@ -526,8 +526,9 @@ def restore_from_dir(self, experiment_dir: Optional[str] = None) -> List[Trial]:

# 3. Load trials
trials = []
for trial_json_state in runner_state["checkpoints"]:
for trial_json_state, trial_runtime_metadata in runner_state["trial_data"]:
trial = Trial.from_json_state(trial_json_state)
trial.restore_run_metadata(trial_runtime_metadata)

# The following properties may be updated on restoration
# Ex: moved local/cloud experiment directory
@@ -598,13 +599,15 @@ def resume(
trials = self.restore_from_dir()

# Set trial statuses according to the resume configuration
for trial in sorted(trials, key=lambda t: t.last_update_time, reverse=True):
for trial in sorted(
trials, key=lambda t: t.run_metadata.last_result_time, reverse=True
):
trial_to_add = trial
if trial.status == Trial.ERROR:
if resume_errored:
# Keep trial ID on resume
trial_to_add.error_filename = None
trial_to_add.pickled_error_filename = None
trial_to_add.run_metadata.error_filename = None
trial_to_add.run_metadata.pickled_error_filename = None
trial_to_add.set_status(Trial.PENDING)
if not _use_storage_context():
# TODO(justinvyu): Remove this.
@@ -1112,7 +1115,7 @@ def _maybe_reuse_cached_actor(self, trial: Trial) -> bool:
ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
cached_actor
][0]
trial.set_runner(ray_actor)
trial.set_ray_actor(ray_actor)

self._schedule_trial_reset(trial, trial.config, trial.experiment_tag)

@@ -1268,7 +1271,7 @@ def _maybe_cache_trial_actor(self, trial: Trial) -> bool:
tracked_actor = self._trial_to_actor.pop(trial)
self._actor_to_trial.pop(tracked_actor)

trial.set_runner(None)
trial.set_ray_actor(None)

return True

@@ -1284,7 +1287,7 @@ def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"):
ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
tracked_actor
][0]
trial.set_runner(ray_actor)
trial.set_ray_actor(ray_actor)

self._callbacks.on_trial_start(
iteration=self._iteration, trials=self._trials, trial=trial
@@ -1302,7 +1305,7 @@ def _actor_stopped(self, tracked_actor: TrackedActor):
trial = self._actor_to_trial.pop(tracked_actor)
logger.debug(f"Actor STOPPED for trial {trial}: {tracked_actor}")
self._trial_to_actor.pop(trial)
trial.set_runner(None)
trial.set_ray_actor(None)

logger.debug(f"Actor STOPPED: {tracked_actor}")

@@ -1522,8 +1525,8 @@ def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = No

logger.debug(f"Requesting to STOP actor for trial {trial}")

trial.saving_to = None
trial.restoring_from = None
trial.temporary_state.saving_to = None
trial.temporary_state.restoring_from = None

self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED)
trial.set_location(_Location())
@@ -1550,7 +1553,7 @@ def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = No
tracked_actor = self._trial_to_actor.pop(trial)
self._actor_to_trial.pop(tracked_actor)

trial.set_runner(None)
trial.set_ray_actor(None)

self._remove_actor(tracked_actor=tracked_actor)

@@ -1849,7 +1852,7 @@ def _schedule_trial_save(
# a done=True result from executing a STOP decision
# (which clears all futures) before the save gets processed.
# Keep this in for now while `train` and `save` are 2 separate steps.
trial.saving_to = True
trial.temporary_state.saving_to = True
# TODO(justinvyu): Remove the return value?
return

@@ -1878,7 +1881,7 @@ def _schedule_trial_save(
storage_mode=storage,
metrics=result,
)
trial.saving_to = checkpoint
trial.temporary_state.saving_to = checkpoint

return checkpoint

@@ -1901,8 +1904,8 @@ def _on_saving_result(self, trial, checkpoint_value: Union[ray.ObjectRef, str]):
"is being synced from the worker to the head node."
)

if trial.location.hostname and (
trial.location.hostname != get_node_ip_address()
if trial.temporary_state.location.hostname and (
trial.temporary_state.location.hostname != get_node_ip_address()
):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
@@ -1931,14 +1934,14 @@ def _process_trial_save(
self._checkpoint_manager.on_trial_checkpoint(trial)
self._mark_trial_to_checkpoint(trial)
else:
trial.saving_to.dir_or_data = checkpoint_value
trial.temporary_state.saving_to.dir_or_data = checkpoint_value
self._callbacks.on_checkpoint(
iteration=self._iteration,
trials=self._trials,
trial=trial,
checkpoint=trial.saving_to,
checkpoint=trial.temporary_state.saving_to,
)
trial.on_checkpoint(trial.saving_to)
trial.on_checkpoint(trial.temporary_state.saving_to)
self._checkpoint_manager.on_trial_checkpoint(trial)
if trial.checkpoint.storage_mode != CheckpointStorage.MEMORY:
self._mark_trial_to_checkpoint(trial)
@@ -1952,7 +1955,7 @@ def _process_trial_save(
"Trial %s: Error handling checkpoint %s", trial, checkpoint_value
)

trial.saving_to = None
trial.temporary_state.saving_to = None
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
if decision and checkpoint_value:
self._queue_decision(trial, decision)
@@ -1961,21 +1964,22 @@ def _checkpoint_trial_if_needed(self, trial, force=False):
"""Checkpoints trial based off trial.last_result."""
if trial.should_checkpoint() or force:
# Save trial runtime if possible.
if trial.runner:
if trial.temporary_state.ray_actor:
self._schedule_trial_save(trial, storage=CheckpointStorage.PERSISTENT)

###
# RESTORE
def _schedule_trial_restore(self, trial: Trial) -> bool:
if _use_storage_context():
checkpoint_result = trial.checkpoint_manager.latest_checkpoint_result
cpm = trial.run_metadata.checkpoint_manager
checkpoint_result = cpm.latest_checkpoint_result

if not checkpoint_result:
logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
return False

# TODO(justinvyu): Is this really needed?
trial.restoring_from = checkpoint_result
trial.temporary_state.restoring_from = checkpoint_result

method_name = "restore"
args = (checkpoint_result,)
@@ -2028,7 +2032,7 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
"storage-based restoration"
)

trial.restoring_from = checkpoint
trial.temporary_state.restoring_from = checkpoint
self._schedule_trial_task(
trial=trial,
method_name=method_name,
@@ -2067,7 +2071,7 @@ def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]):
self._cached_trial_decisions.pop(trial.trial_id, None)
# Resetting this, in case that the trial is in saving status when it crashes.
if trial.is_saving:
trial.saving_to = None
trial.temporary_state.saving_to = None
if trial.is_restoring and exc:
exc = _TuneRestoreError(exc)
self._schedule_trial_stop(trial, exception=exc)