From c1ba2229bfb3b0de7b391474c15910a9115c270b Mon Sep 17 00:00:00 2001 From: hvy Date: Thu, 9 Jul 2020 14:50:35 +0900 Subject: [PATCH 1/2] Fix CachedStorage skipping trial param row insertion --- optuna/storages/_cached_storage.py | 11 ++++------- optuna/storages/_rdb/storage.py | 24 ++++-------------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index a99fe48d85..f04136b2c1 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -229,9 +229,7 @@ def set_trial_param( if cached_dist: distributions.check_distribution_compatibility(cached_dist, distribution) else: - self._backend._check_or_set_param_distribution( - trial_id, param_name, param_value_internal, distribution - ) + self._backend._check_param_distribution(trial_id, param_name, distribution) self._studies[study_id].param_distribution[param_name] = distribution params = copy.copy(cached_trial.params) @@ -242,10 +240,9 @@ def set_trial_param( dists[param_name] = distribution cached_trial.distributions = dists - if cached_dist: - updates = self._get_updates(trial_id) - updates.params[param_name] = param_value_internal - updates.distributions[param_name] = distribution + updates = self._get_updates(trial_id) + updates.params[param_name] = param_value_internal + updates.distributions[param_name] = distribution return self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index dd60cf2e38..1fb48b8674 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -750,18 +750,13 @@ def _set_trial_param_without_commit( trial_param.check_and_add(session) - def _check_or_set_param_distribution( - self, - trial_id: int, - param_name: str, - param_value_internal: float, - distribution: distributions.BaseDistribution, + def _check_param_distribution( + self, trial_id: int, param_name: str, distribution: distributions.BaseDistribution, ) -> None: session = self.scoped_session() - # Acquire a lock of this trial. - trial = models.TrialModel.find_by_id(trial_id, session, for_update=True) + trial = models.TrialModel.find_by_id(trial_id, session) if trial is None: raise KeyError(models.NOT_FOUND_MSG) @@ -772,23 +767,12 @@ def _check_or_set_param_distribution( .filter(models.TrialParamModel.param_name == param_name) .first() ) + if previous_record is not None: distributions.check_distribution_compatibility( distributions.json_to_distribution(previous_record.distribution_json), distribution, ) - else: - session.add( - models.TrialParamModel( - trial_id=trial_id, - param_name=param_name, - param_value=param_value_internal, - distribution_json=distributions.distribution_to_json(distribution), - ) - ) - - # Release lock. - session.commit() def get_trial_param(self, trial_id, param_name): # type: (int, str) -> float From f51e87419c23726dabab2569a74e0739b7f39585 Mon Sep 17 00:00:00 2001 From: hvy Date: Thu, 9 Jul 2020 18:39:04 +0900 Subject: [PATCH 2/2] Fix distribution compatibility check --- optuna/storages/_cached_storage.py | 15 +++++++++++---- optuna/storages/_rdb/storage.py | 25 +++++++++++++++++++++---- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index f04136b2c1..daa9c9eb82 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -229,7 +229,13 @@ def set_trial_param( if cached_dist: distributions.check_distribution_compatibility(cached_dist, distribution) else: - self._backend._check_param_distribution(trial_id, param_name, distribution) + # On cache miss, check compatibility against previous trials in the database + # and INSERT immediately to prevent other processes from creating incompatible + # ones. By INSERT, it is assumed that no previous entry has been persisted + # already. + self._backend._check_and_set_param_distribution( + trial_id, param_name, param_value_internal, distribution + ) self._studies[study_id].param_distribution[param_name] = distribution params = copy.copy(cached_trial.params) @@ -240,9 +246,10 @@ def set_trial_param( dists[param_name] = distribution cached_trial.distributions = dists - updates = self._get_updates(trial_id) - updates.params[param_name] = param_value_internal - updates.distributions[param_name] = distribution + if cached_dist: # Already persisted in case of cache miss so no need to update. + updates = self._get_updates(trial_id) + updates.params[param_name] = param_value_internal + updates.distributions[param_name] = distribution return self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 1fb48b8674..c7f9f9f97b 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -750,13 +750,19 @@ def _set_trial_param_without_commit( trial_param.check_and_add(session) - def _check_param_distribution( - self, trial_id: int, param_name: str, distribution: distributions.BaseDistribution, + def _check_and_set_param_distribution( + self, + trial_id: int, + param_name: str, + param_value_internal: float, + distribution: distributions.BaseDistribution, ) -> None: session = self.scoped_session() - trial = models.TrialModel.find_by_id(trial_id, session) + # Acquire a lock of this trial. + trial = models.TrialModel.find_by_id(trial_id, session, for_update=True) + if trial is None: raise KeyError(models.NOT_FOUND_MSG) @@ -767,13 +773,24 @@ def _check_param_distribution( .filter(models.TrialParamModel.param_name == param_name) .first() ) - if previous_record is not None: distributions.check_distribution_compatibility( distributions.json_to_distribution(previous_record.distribution_json), distribution, ) + session.add( + models.TrialParamModel( + trial_id=trial_id, + param_name=param_name, + param_value=param_value_internal, + distribution_json=distributions.distribution_to_json(distribution), + ) + ) + + # Release lock. + session.commit() + def get_trial_param(self, trial_id, param_name): # type: (int, str) -> float