Skip to content

Commit

Permalink
Add TrialStatus.EARLY_STOPPED
Browse files Browse the repository at this point in the history
Summary: As discussed on D28920556 (5e666be), add a new TrialStatus for early-stopped trials.

Reviewed By: Balandat

Differential Revision: D28996613

fbshipit-source-id: a22ac8299bbd085cadbf7db7933a2ea4d3853fa4
  • Loading branch information
ldworkin authored and facebook-github-bot committed Jun 10, 2021
1 parent 442b519 commit 73220ba
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
19 changes: 18 additions & 1 deletion ax/core/base_trial.py
Expand Up @@ -33,12 +33,14 @@ class TrialStatus(int, Enum):
CANDIDATE --> STAGED --> RUNNING --> COMPLETED
-------------> --> FAILED (machine failure)
--> EARLY_STOPPED (deemed unpromising)
-------------------------> ABANDONED (human-initiated action)
Trials may be abandoned at any time prior to completion or failure
via human intervention. The difference between abandonment and failure
is that the former is human-directed, while the latter is an internal
failure state.
failure state. Early-stopped refers to trials that were deemed
unpromising by an early-stopping strategy and therefore terminated.
Additionally, when trials are deployed, they may be in an intermediate
staged state (e.g. scheduled but waiting for resources) or immediately
Expand All @@ -56,6 +58,7 @@ class TrialStatus(int, Enum):
RUNNING = 4
ABANDONED = 5
DISPATCHED = 6 # Deprecated.
EARLY_STOPPED = 7

@property
def is_terminal(self) -> bool:
Expand Down Expand Up @@ -572,6 +575,18 @@ def mark_failed(self) -> BaseTrial:
self._time_completed = datetime.now()
return self

def mark_early_stopped(self) -> BaseTrial:
    """Move a running trial into the ``EARLY_STOPPED`` status.

    On success the trial's completion time is set to the current
    (naive, local) ``datetime.now()``.

    Returns:
        The trial instance.

    Raises:
        ValueError: If the trial's current status is not ``RUNNING``.
    """
    is_running = self._status == TrialStatus.RUNNING
    if not is_running:
        # Only a trial that is actually running can be stopped early.
        raise ValueError("Can only early stop trial that is currently running.")

    self._status = TrialStatus.EARLY_STOPPED
    self._time_completed = datetime.now()
    return self

def mark_as(self, status: TrialStatus, **kwargs: Any) -> BaseTrial:
"""Mark trial with a new TrialStatus.
Expand All @@ -594,6 +609,8 @@ def mark_as(self, status: TrialStatus, **kwargs: Any) -> BaseTrial:
self.mark_failed()
elif status == TrialStatus.COMPLETED:
self.mark_completed()
elif status == TrialStatus.EARLY_STOPPED:
self.mark_early_stopped()
else:
raise ValueError(f"Cannot mark trial as {status}.")
return self
Expand Down
8 changes: 8 additions & 0 deletions ax/core/tests/test_batch_trial.py
Expand Up @@ -349,6 +349,14 @@ def testFailedBatchTrial(self):
self.assertEqual(self.batch.status, TrialStatus.FAILED)
self.assertIsNotNone(self.batch.time_completed)

def testEarlyStoppedBatchTrial(self):
    """Early stopping a running batch trial sets its status and completion time."""
    batch = self.batch
    batch.runner = SyntheticRunner()
    # The trial is run first, then early stopped.
    batch.run()
    batch.mark_early_stopped()

    self.assertEqual(batch.status, TrialStatus.EARLY_STOPPED)
    self.assertIsNotNone(batch.time_completed)

def testAbandonArm(self):
arm = self.batch.arms[0]
reason = "Bad arm"
Expand Down
22 changes: 12 additions & 10 deletions ax/service/tests/test_scheduler.py
Expand Up @@ -71,7 +71,10 @@ def report_results(self) -> Tuple[bool, Dict[str, Set[int]]]:
# will be a pointer and all will be the same
"trials_completed_so_far": set(
self.experiment.trial_indices_by_status[TrialStatus.COMPLETED]
)
),
"trials_early_stopped_so_far": set(
self.experiment.trial_indices_by_status[TrialStatus.EARLY_STOPPED]
),
},
)

Expand Down Expand Up @@ -615,7 +618,7 @@ def poll_trial_status(self):
return {}

def should_stop_trials_early(self, trial_indices: Set[int]):
return {TrialStatus.COMPLETED: trial_indices}
return {TrialStatus.EARLY_STOPPED: trial_indices}

total_trials = 3
scheduler = EarlyStopsInsteadOfNormalCompletionScheduler(
Expand All @@ -641,9 +644,9 @@ def should_stop_trials_early(self, trial_indices: Set[int]):
self.assertEqual(len(res_list), expected_num_polls)
self.assertIsInstance(res_list, list)
# Both trials in first batch of parallelism will be early stopped
self.assertEqual(len(res_list[0]["trials_completed_so_far"]), 2)
self.assertEqual(len(res_list[0]["trials_early_stopped_so_far"]), 2)
# Third trial in second batch of parallelism will be early stopped
self.assertEqual(len(res_list[1]["trials_completed_so_far"]), 3)
self.assertEqual(len(res_list[1]["trials_early_stopped_so_far"]), 3)
self.assertEqual(
mock_should_stop_trials_early.call_count, expected_num_polls
)
Expand Down Expand Up @@ -672,7 +675,7 @@ def should_stop_trials_early(
new_statuses = {}
for trial_index in trial_indices:
if trial_index % 2 == 1:
new_statuses[trial_index] = TrialStatus.COMPLETED
new_statuses[trial_index] = TrialStatus.EARLY_STOPPED
else:
new_statuses[trial_index] = None
return new_statuses
Expand Down Expand Up @@ -707,12 +710,11 @@ def poll_trial_status(self):
)
expected_num_steps = 2
self.assertEqual(len(res_list), expected_num_steps)
# Trial #1 completed in first step
self.assertDictEqual(res_list[0], {"trials_completed_so_far": {1}})
# Trial #1 early stopped in first step
self.assertEqual(res_list[0]["trials_early_stopped_so_far"], {1})
# By end of second step, trial #1 remains early stopped and trials #0 and #2 completed
self.assertDictEqual(
res_list[1], {"trials_completed_so_far": set(range(total_trials))}
)
self.assertEqual(res_list[1]["trials_early_stopped_so_far"], {1})
self.assertEqual(res_list[1]["trials_completed_so_far"], {0, 2})
self.assertEqual(mock_stop_trial_runs.call_count, expected_num_steps)

def test_run_trials_in_batches(self):
Expand Down

0 comments on commit 73220ba

Please sign in to comment.