diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index a99fe48d85..daa9c9eb82 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -229,7 +229,11 @@ def set_trial_param( if cached_dist: distributions.check_distribution_compatibility(cached_dist, distribution) else: - self._backend._check_or_set_param_distribution( + # On cache miss, check compatibility against previous trials in the database + # and INSERT immediately to prevent other processes from creating incompatible + # ones. By INSERT, it is assumed that no previous entry has been persisted + # already. + self._backend._check_and_set_param_distribution( trial_id, param_name, param_value_internal, distribution ) self._studies[study_id].param_distribution[param_name] = distribution @@ -242,7 +246,7 @@ def set_trial_param( dists[param_name] = distribution cached_trial.distributions = dists - if cached_dist: + if cached_dist: # Already persisted in case of cache miss so no need to update. updates = self._get_updates(trial_id) updates.params[param_name] = param_value_internal updates.distributions[param_name] = distribution diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index ff4b811edb..f68136dc4a 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -774,7 +774,7 @@ def _set_trial_param_without_commit( trial_param.check_and_add(session) - def _check_or_set_param_distribution( + def _check_and_set_param_distribution( self, trial_id: int, param_name: str, @@ -786,6 +786,7 @@ def _check_or_set_param_distribution( # Acquire a lock of this trial. trial = models.TrialModel.find_by_id(trial_id, session, for_update=True) + if trial is None: raise KeyError(models.NOT_FOUND_MSG) @@ -801,15 +802,15 @@ def _check_or_set_param_distribution( distributions.json_to_distribution(previous_record.distribution_json), distribution, ) - else: - session.add( - models.TrialParamModel( - trial_id=trial_id, - param_name=param_name, - param_value=param_value_internal, - distribution_json=distributions.distribution_to_json(distribution), - ) + + session.add( + models.TrialParamModel( + trial_id=trial_id, + param_name=param_name, + param_value=param_value_internal, + distribution_json=distributions.distribution_to_json(distribution), ) + ) # Release lock. session.commit()