Skip to content

Commit

Permalink
Merge pull request #3581 from toshihikoyanase/skip-table-creation
Browse files Browse the repository at this point in the history
Add option to skip table creation to `RDBStorage`.
  • Loading branch information
contramundum53 committed May 24, 2022
2 parents 6dd08af + 37b522a commit 6349240
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion optuna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def take_action(self, parsed_args: Namespace) -> None:
if storage_url.startswith("redis"):
self.logger.info("This storage does not support upgrade yet.")
return
storage = RDBStorage(storage_url, skip_compatibility_check=True)
storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
current_version = storage.get_current_version()
head_version = storage.get_head_version()
known_versions = storage.get_all_versions()
Expand Down
9 changes: 7 additions & 2 deletions optuna/storages/_rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def objective(trial):
A dictionary of keyword arguments that is passed to
`sqlalchemy.engine.create_engine`_ function.
skip_compatibility_check:
Flag to skip schema compatibility check if set to True.
Flag to skip schema compatibility check if set to :obj:`True`.
heartbeat_interval:
Interval to record the heartbeat. It is recorded every ``interval`` seconds.
``heartbeat_interval`` must be :obj:`None` or a positive integer.
Expand All @@ -156,6 +156,9 @@ def objective(trial):
The procedure to fail existing stale trials is called just before asking the
study for a new trial.
skip_table_creation:
Flag to skip table creation if set to :obj:`True`.
.. _sqlalchemy.engine.create_engine:
https://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
Expand Down Expand Up @@ -192,6 +195,7 @@ def __init__(
heartbeat_interval: Optional[int] = None,
grace_period: Optional[int] = None,
failed_trial_callback: Optional[Callable[["optuna.Study", FrozenTrial], None]] = None,
skip_table_creation: bool = False,
) -> None:

self.engine_kwargs = engine_kwargs or {}
Expand All @@ -218,7 +222,8 @@ def __init__(
self.scoped_session = sqlalchemy_orm.scoped_session(
sqlalchemy_orm.sessionmaker(bind=self.engine)
)
models.BaseModel.metadata.create_all(self.engine)
if not skip_table_creation:
models.BaseModel.metadata.create_all(self.engine)

self._version_manager = _VersionManager(self.url, self.engine, self.scoped_session)
if not skip_compatibility_check:
Expand Down
6 changes: 3 additions & 3 deletions tests/storages_tests/rdb_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_upgrade_single_objective_optimization(optuna_version: str) -> None:
shutil.copyfile(src_db_file, f"{workdir}/sqlite.db")
storage_url = f"sqlite:///{workdir}/sqlite.db"

storage = RDBStorage(storage_url, skip_compatibility_check=True)
storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
assert storage.get_current_version() == f"v{optuna_version}"
head_version = storage.get_head_version()
storage.upgrade()
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_upgrade_multi_objective_optimization(optuna_version: str) -> None:
shutil.copyfile(src_db_file, f"{workdir}/sqlite.db")
storage_url = f"sqlite:///{workdir}/sqlite.db"

storage = RDBStorage(storage_url, skip_compatibility_check=True)
storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
assert storage.get_current_version() == f"v{optuna_version}"
head_version = storage.get_head_version()
storage.upgrade()
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_upgrade_distributions(optuna_version: str) -> None:
shutil.copyfile(src_db_file, f"{workdir}/sqlite.db")
storage_url = f"sqlite:///{workdir}/sqlite.db"

storage = RDBStorage(storage_url, skip_compatibility_check=True)
storage = RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
old_study = load_study(storage=storage, study_name="schema migration")
old_distribution_dict = old_study.trials[0].distributions

Expand Down

0 comments on commit 6349240

Please sign in to comment.