Skip to content

Commit

Permalink
[air/output] Add callback hook for trial recovery, only print error t…
Browse files Browse the repository at this point in the history
…able at end (ray-project#37572)

This changes the context-aware output handler so that trial errors are immediately reported with their respective error files. The error table is only printed at the end.

This introduces a new callback hook, `on_trial_recover`, which is required so that error files are also available in the immediate output when a trial has a transient failure.

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
krfricke authored and arvind-chandra committed Aug 31, 2023
1 parent d403fc3 commit 8b02500
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
20 changes: 20 additions & 0 deletions python/ray/tune/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,22 @@ def on_trial_complete(
"""
pass

def on_trial_recover(
    self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
    """Called after a trial instance failed (errored) but the trial is scheduled
    for retry.

    The search algorithm and scheduler are not notified.

    Arguments:
        iteration: Number of iterations of the tuning loop.
        trials: List of trials.
        trial: Trial that has just errored.
        **info: Kwargs dict for forward compatibility.
    """
    # Default implementation is intentionally a no-op hook.
    pass

def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
Expand Down Expand Up @@ -399,6 +415,10 @@ def on_trial_complete(self, **info):
for callback in self._callbacks:
callback.on_trial_complete(**info)

def on_trial_recover(self, **info):
    """Forward the ``on_trial_recover`` event to each callback in order.

    Arguments:
        **info: Keyword arguments passed unchanged to every callback in
            ``self._callbacks``.
    """
    for callback in self._callbacks:
        callback.on_trial_recover(**info)

def on_trial_error(self, **info):
    """Forward the ``on_trial_error`` event to each callback in order.

    Arguments:
        **info: Keyword arguments passed unchanged to every callback in
            ``self._callbacks``.
    """
    for callback in self._callbacks:
        callback.on_trial_error(**info)
Expand Down
5 changes: 4 additions & 1 deletion python/ray/tune/execution/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,13 +692,16 @@ def _process_trial_failure(
if trial.status == Trial.RUNNING:
if trial.should_recover():
self._try_recover(trial, exc=exception)
self._callbacks.on_trial_recover(
iteration=self._iteration, trials=self._trials, trial=trial
)
else:
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
self._schedule_trial_stop(trial, exception=exception)
self._callbacks.on_trial_error(
iteration=self._iteration, trials=self._trials, trial=trial
)
self._schedule_trial_stop(trial, exception=exception)

###
# STOP
Expand Down
26 changes: 26 additions & 0 deletions python/ray/tune/experimental/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,28 @@ def on_trial_complete(
)
self._print_result(trial)

def on_trial_error(
    self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
    """Print an error notice for ``trial``, including its error file path.

    The message states how many training iterations the trial reported in
    its last result (0 if there is none), the current time and the total
    running time, followed by the trial's last result.

    Arguments:
        iteration: Unused here; part of the callback signature.
        trials: Unused here; part of the callback signature.
        trial: Trial that errored.
        **info: Additional keyword arguments; ignored.
    """
    curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())

    # Fall back to 0 when the trial never reported a training iteration.
    finished_iter = 0
    if trial.last_result and TRAINING_ITERATION in trial.last_result:
        finished_iter = trial.last_result[TRAINING_ITERATION]

    self._start_block(f"trial_{trial}_error")
    print(
        f"{self._addressing_tmpl.format(trial)} "
        f"errored after {finished_iter} iterations "
        f"at {curr_time_str}. Total running time: {running_time_str}\n"
        f"Error file: {trial.error_file}"
    )
    self._print_result(trial)

def on_trial_recover(
    self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
    """Report a trial that failed but will be retried.

    Delegates to ``self.on_trial_error`` with the same arguments, so the
    output is the same as for an errored trial.

    Arguments:
        iteration: Forwarded to ``on_trial_error``.
        trials: Forwarded to ``on_trial_error``.
        trial: Forwarded to ``on_trial_error``.
        **info: Additional keyword arguments, forwarded unchanged.
    """
    self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)

def on_checkpoint(
self,
iteration: int,
Expand Down Expand Up @@ -970,6 +992,10 @@ def _print_heartbeat(self, trials, *sys_args, force: bool = False):
if more_infos:
print(", ".join(more_infos))

if not force:
# Only print error table at end of training
return

trials_with_error = _get_trials_with_error(trials)
if not trials_with_error:
return
Expand Down

0 comments on commit 8b02500

Please sign in to comment.