Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a concurrency bug of JournalStorage set_trial_state_values. #4033

Merged
merged 12 commits into from
Oct 24, 2022
74 changes: 35 additions & 39 deletions optuna/storages/_journal/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ def objective(trial):
"""

def __init__(self, log_storage: BaseJournalLogStorage) -> None:
self._pid = str(uuid.uuid4())

self._worker_id_prefix = str(uuid.uuid4()) + "-"
c-bata marked this conversation as resolved.
Show resolved Hide resolved
self._backend = log_storage
self._thread_lock = threading.Lock()
self._replay_result = JournalStorageReplayResult(self._pid)
self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)

with self._thread_lock:
self._sync_with_backend()

def _write_log(self, op_code: int, extra_fields: Dict[str, Any]) -> None:
self._backend.append_logs([{"op_code": op_code, "pid": self._pid, **extra_fields}])
worker_id = self._worker_id_prefix + str(threading.get_ident())
self._backend.append_logs([{"op_code": op_code, "worker_id": worker_id, **extra_fields}])

def _sync_with_backend(self) -> None:
logs = self._backend.read_logs(self._replay_result.log_number_read)
Expand Down Expand Up @@ -263,7 +263,8 @@ def set_trial_state_values(

if (
state == TrialState.RUNNING
and trial_id not in self._replay_result._trial_ids_owned_by_this_process
and trial_id
!= self._replay_result._thread_id_to_owned_trial_id.get(threading.get_ident())
):
return False
else:
Expand Down Expand Up @@ -322,16 +323,16 @@ def get_all_trials(


class JournalStorageReplayResult:
def __init__(self, pid: str) -> None:
def __init__(self, worker_id_prefix: str) -> None:
self.log_number_read = 0
self._pid = pid
self._worker_id_prefix = worker_id_prefix
self._studies: Dict[int, FrozenStudy] = {}
self._trials: Dict[int, FrozenTrial] = {}

self._study_id_to_trial_ids: Dict[int, List[int]] = {}
self._trial_id_to_study_id: Dict[int, int] = {}
self._next_study_id: int = 0
self._trial_ids_owned_by_this_process: List[int] = []
self._thread_id_to_owned_trial_id: Dict[int, int] = {}

def apply_logs(self, logs: List[Dict[str, Any]]) -> None:
for log in logs:
Expand Down Expand Up @@ -388,29 +389,27 @@ def get_all_trials(
frozen_trials.append(trial)
return frozen_trials

def _raise_if_log_issued_by_pid(self, log: Dict[str, Any], err: Exception) -> None:
if log["pid"] == self._pid:
raise err
def _is_issued_by_this_worker(self, log: Dict[str, Any]) -> bool:
return log["worker_id"] == self._worker_id_prefix + str(threading.get_ident())

def _study_exists(self, study_id: int, log: Dict[str, Any]) -> bool:
if study_id not in self._studies:
self._raise_if_log_issued_by_pid(log, KeyError(NOT_FOUND_MSG))
return False
return True
if study_id in self._studies:
return True
if self._is_issued_by_this_worker(log):
raise KeyError(NOT_FOUND_MSG)
return False

def _apply_create_study(self, log: Dict[str, Any]) -> None:
study_name = log["study_name"]

if study_name in [s.study_name for s in self._studies.values()]:
self._raise_if_log_issued_by_pid(
log,
DuplicatedStudyError(
if self._is_issued_by_this_worker(log):
raise DuplicatedStudyError(
"Another study with name '{}' already exists. "
"Please specify a different name, or reuse the existing one "
"by setting `load_if_exists` (for Python API) or "
"`--skip-if-exists` flag (for CLI).".format(study_name)
),
)
)
return

study_id = self._next_study_id
Expand Down Expand Up @@ -456,14 +455,12 @@ def _apply_set_study_directions(self, log: Dict[str, Any]) -> None:

current_directions = self._studies[study_id]._directions
if current_directions[0] != StudyDirection.NOT_SET and current_directions != directions:
self._raise_if_log_issued_by_pid(
log,
ValueError(
if self._is_issued_by_this_worker(log):
raise ValueError(
"Cannot overwrite study direction from {} to {}.".format(
current_directions, directions
)
),
)
)
return

self._studies[study_id]._directions = [StudyDirection(d) for d in directions]
Expand Down Expand Up @@ -508,11 +505,10 @@ def _apply_create_trial(self, log: Dict[str, Any]) -> None:
self._study_id_to_trial_ids[study_id].append(trial_id)
self._trial_id_to_study_id[trial_id] = study_id

if log["pid"] == self._pid and self._trials[trial_id].state == TrialState.RUNNING:
self._trial_ids_owned_by_this_process.append(trial_id)

if log["pid"] == self._pid:
if self._is_issued_by_this_worker(log):
self._last_created_trial_id_by_this_process = trial_id
c-bata marked this conversation as resolved.
Show resolved Hide resolved
if self._trials[trial_id].state == TrialState.RUNNING:
self._thread_id_to_owned_trial_id[threading.get_ident()] = trial_id

def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
trial_id = log["trial_id"]
Expand All @@ -533,8 +529,9 @@ def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
check_distribution_compatibility(
prev_trial.distributions[param_name], distribution
)
except Exception as e:
self._raise_if_log_issued_by_pid(log, e)
except Exception:
if self._is_issued_by_this_worker(log):
raise
return
break

Expand All @@ -559,15 +556,15 @@ def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
trial = copy.copy(self._trials[trial_id])
if state == TrialState.RUNNING:
trial.datetime_start = datetime_from_isoformat(log["datetime_start"])
self._trial_ids_owned_by_this_process.append(trial_id)
if self._is_issued_by_this_worker(log):
self._thread_id_to_owned_trial_id[threading.get_ident()] = trial_id
if state.is_finished():
trial.datetime_complete = datetime_from_isoformat(log["datetime_complete"])
trial.state = state
if log["values"] is not None:
trial.values = log["values"]

self._trials[trial_id] = trial
return

def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
trial_id = log["trial_id"]
Expand Down Expand Up @@ -603,17 +600,16 @@ def _apply_set_trial_system_attr(self, log: Dict[str, Any]) -> None:

def _trial_exists_and_updatable(self, trial_id: int, log: Dict[str, Any]) -> bool:
if trial_id not in self._trials:
self._raise_if_log_issued_by_pid(log, KeyError(NOT_FOUND_MSG))
if self._is_issued_by_this_worker(log):
raise KeyError(NOT_FOUND_MSG)
return False
elif self._trials[trial_id].state.is_finished():
self._raise_if_log_issued_by_pid(
log,
RuntimeError(
if self._is_issued_by_this_worker(log):
raise RuntimeError(
"Trial#{} has already finished and can not be updated.".format(
self._trials[trial_id].number
)
),
)
)
return False
else:
return True
31 changes: 31 additions & 0 deletions tests/storages_tests/test_journal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
import tempfile
Expand Down Expand Up @@ -65,3 +66,33 @@ def test_concurrent_append_logs(log_storage_type: str) -> None:

assert len(storage.read_logs(0)) == num_records
assert all(record == r for r in storage.read_logs(0))


def pop_waiting_trial(file_path: str, study_name: str) -> Optional[int]:
    """Pop one waiting trial from a study backed by a journal file.

    A fresh ``JournalFileStorage``/``JournalStorage`` pair is built on every
    call from ``file_path``, so the function can be submitted to a process
    pool using only plain string arguments.

    Args:
        file_path: Path of the journal file backing the storage.
        study_name: Name of the study to load from that storage.

    Returns:
        Whatever ``study._pop_waiting_trial_id()`` returns: a trial id, or
        ``None``.
    """
    journal = optuna.storages.JournalStorage(optuna.storages.JournalFileStorage(file_path))
    loaded_study = optuna.load_study(study_name=study_name, storage=journal)
    return loaded_study._pop_waiting_trial_id()


def test_pop_waiting_trial_multiprocess_safe() -> None:
with tempfile.NamedTemporaryFile() as file:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I ask one thing? Is there any reason why we can't run the multi-process test for other storage backends, such as SQLite3?

Copy link
Member Author

@c-bata c-bata Oct 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to journal, the STORAGE_MODES variable, which is used with pytest.mark.parametrize(), contains sqlite3, redis, and inmemory.

  • sqlite3: To pop waiting trials concurrently, RDBStorage uses the SELECT ... FOR UPDATE syntax, which is unsupported in SQLite3. Furthermore, SQLite3 does not support a high level of concurrency: it sometimes raises "database is locked" errors, which makes the tests fragile.
  • redis: We use fakeredis for testing RedisStorage. It doesn't support multiple processes.

So I added this dedicated test case here instead of using pytest.mark.parametrize("storage_mode", STORAGE_MODES).

file_storage = optuna.storages.JournalFileStorage(file.name)
storage = optuna.storages.JournalStorage(file_storage)
study = optuna.create_study(storage=storage)
num_enqueued = 10
for i in range(num_enqueued):
study.enqueue_trial({"i": i})

trial_id_set = set()
with ProcessPoolExecutor(10) as pool:
futures = []
for i in range(num_enqueued):
future = pool.submit(pop_waiting_trial, file.name, study.study_name)
futures.append(future)

for future in as_completed(futures):
trial_id = future.result()
if trial_id is not None:
trial_id_set.add(trial_id)
assert len(trial_id_set) == num_enqueued
28 changes: 26 additions & 2 deletions tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import concurrent.futures
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
import copy
import itertools
import multiprocessing
Expand Down Expand Up @@ -177,7 +178,7 @@ def objective(t: Trial) -> float:
return t.suggest_float("x", -10, 10)

study = create_study()
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
with ThreadPoolExecutor(max_workers=5) as pool:
for _ in range(10):
pool.submit(study.optimize, objective, n_trials=10)
assert len(study.trials) == 100
Expand Down Expand Up @@ -1547,3 +1548,26 @@ def objective(trial: Trial) -> float:
study.enqueue_trial(params={"x": 1}, skip_if_exists=False)
summaries = get_all_study_summaries(study._storage, include_best_trial=True)
assert summaries[0].datetime_start is not None


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_pop_waiting_trial_thread_safe(storage_mode: str) -> None:
    """Check that concurrent threads each pop a distinct waiting trial.

    Enqueues ``num_enqueued`` trials, then calls
    ``study._pop_waiting_trial_id`` once per enqueued trial from a thread
    pool and asserts that every call claimed a different trial id.

    Args:
        storage_mode: Storage backend name supplied by ``STORAGE_MODES``.
    """
    if "sqlite" == storage_mode or "cached_sqlite" == storage_mode:
        pytest.skip("study._pop_waiting_trial is not thread-safe on SQLite3")

    num_enqueued = 10
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        for i in range(num_enqueued):
            study.enqueue_trial({"i": i})

        trial_id_set = set()
        with ThreadPoolExecutor(10) as pool:
            futures = [
                pool.submit(study._pop_waiting_trial_id) for _ in range(num_enqueued)
            ]

            for future in as_completed(futures):
                trial_id = future.result()
                # A pop that claimed nothing yields None; counting it as an id
                # would let the size check pass even though a trial was never
                # popped.
                if trial_id is not None:
                    trial_id_set.add(trial_id)
        assert len(trial_id_set) == num_enqueued