Skip to content

Commit

Permalink
Add a test for multi-process and fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Oct 13, 2022
1 parent 6f20830 commit eec3075
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
5 changes: 3 additions & 2 deletions optuna/storages/_journal/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,16 @@ def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
trial = copy.copy(self._trials[trial_id])
if state == TrialState.RUNNING:
trial.datetime_start = datetime_from_isoformat(log["datetime_start"])
self._trial_ids[threading.get_ident()] = trial_id
worker_id = self._worker_id_prefix + str(threading.get_ident())
if log["worker_id"] == worker_id:
self._trial_ids[threading.get_ident()] = trial_id
if state.is_finished():
trial.datetime_complete = datetime_from_isoformat(log["datetime_complete"])
trial.state = state
if log["values"] is not None:
trial.values = log["values"]

self._trials[trial_id] = trial
return

def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
trial_id = log["trial_id"]
Expand Down
28 changes: 28 additions & 0 deletions tests/storages_tests/test_journal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
import tempfile
Expand Down Expand Up @@ -65,3 +66,30 @@ def test_concurrent_append_logs(log_storage_type: str) -> None:

assert len(storage.read_logs(0)) == num_records
assert all(record == r for r in storage.read_logs(0))


def pop_waiting_trial(file_path: str, study_name: str) -> int:
file_storage = optuna.storages.JournalFileStorage(file_path)
storage = optuna.storages.JournalStorage(file_storage)
study = optuna.load_study(storage=storage, study_name=study_name)
return study._pop_waiting_trial_id()


def test_set_trial_state_values_multiprocess_safe() -> None:
with tempfile.NamedTemporaryFile() as file:
file_storage = optuna.storages.JournalFileStorage(file.name)
storage = optuna.storages.JournalStorage(file_storage)
study = optuna.create_study(storage=storage)
for i in range(10):
study.enqueue_trial({"i": i})

trial_id_set = set()
with ProcessPoolExecutor(10) as pool:
futures = []
for i in range(10):
future = pool.submit(pop_waiting_trial, file.name, study.study_name)
futures.append(future)

for future in as_completed(futures):
trial_id_set.add(future.result())
assert len(trial_id_set) == 10

0 comments on commit eec3075

Please sign in to comment.