Skip to content

Commit

Permalink
Merge pull request #1498 from hvy/fix-cached-storage-param-insertion
Browse files Browse the repository at this point in the history
Fix `CachedStorage` skipping trial param row insertion on cache miss.
  • Loading branch information
toshihikoyanase committed Jul 13, 2020
2 parents 4be9da5 + f51e874 commit 6c74908
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
8 changes: 6 additions & 2 deletions optuna/storages/_cached_storage.py
Expand Up @@ -229,7 +229,11 @@ def set_trial_param(
if cached_dist:
distributions.check_distribution_compatibility(cached_dist, distribution)
else:
self._backend._check_or_set_param_distribution(
# On cache miss, check compatibility against previous trials in the database
# and INSERT immediately to prevent other processes from creating incompatible
# ones. By INSERT, it is assumed that no previous entry has been persisted
# already.
self._backend._check_and_set_param_distribution(
trial_id, param_name, param_value_internal, distribution
)
self._studies[study_id].param_distribution[param_name] = distribution
Expand All @@ -242,7 +246,7 @@ def set_trial_param(
dists[param_name] = distribution
cached_trial.distributions = dists

if cached_dist:
if cached_dist: # Already persisted in case of cache miss so no need to update.
updates = self._get_updates(trial_id)
updates.params[param_name] = param_value_internal
updates.distributions[param_name] = distribution
Expand Down
19 changes: 10 additions & 9 deletions optuna/storages/_rdb/storage.py
Expand Up @@ -774,7 +774,7 @@ def _set_trial_param_without_commit(

trial_param.check_and_add(session)

def _check_or_set_param_distribution(
def _check_and_set_param_distribution(
self,
trial_id: int,
param_name: str,
Expand All @@ -786,6 +786,7 @@ def _check_or_set_param_distribution(

# Acquire a lock of this trial.
trial = models.TrialModel.find_by_id(trial_id, session, for_update=True)

if trial is None:
raise KeyError(models.NOT_FOUND_MSG)

Expand All @@ -801,15 +802,15 @@ def _check_or_set_param_distribution(
distributions.json_to_distribution(previous_record.distribution_json),
distribution,
)
else:
session.add(
models.TrialParamModel(
trial_id=trial_id,
param_name=param_name,
param_value=param_value_internal,
distribution_json=distributions.distribution_to_json(distribution),
)

session.add(
models.TrialParamModel(
trial_id=trial_id,
param_name=param_name,
param_value=param_value_internal,
distribution_json=distributions.distribution_to_json(distribution),
)
)

# Release lock.
session.commit()
Expand Down

0 comments on commit 6c74908

Please sign in to comment.