Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into tpe-multivariate-all-trials
Browse files Browse the repository at this point in the history
  • Loading branch information
knshnb committed Nov 6, 2022
2 parents 9b82a11 + bccb63d commit 2251454
Show file tree
Hide file tree
Showing 13 changed files with 213 additions and 50 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/tests-storage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
# RDB. Since current name "tests-rdbstorage" is required in the Branch protection rules, you
# need to modify the Branch protection rules as well.
tests-rdbstorage:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04

strategy:
matrix:
Expand Down Expand Up @@ -117,3 +117,11 @@ jobs:
env:
OMP_NUM_THREADS: 1
TEST_DB_URL: redis://localhost:6379

- name: Tests Journal Redis
run: |
pytest tests/storages_tests/test_with_server.py
env:
OMP_NUM_THREADS: 1
TEST_DB_URL: redis://localhost:6379
TEST_DB_MODE: journal-redis
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ concurrency:

jobs:
tests:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04

strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Optuna depends on Sphinx to build the documentation HTML files from the correspo
but as you may notice, [Tutorial directory](https://github.com/optuna/optuna/tree/master/tutorial) does not have any `.rst` files. Instead, it has a bunch of Python (`.py`) files.
We have [Sphinx Gallery](https://sphinx-gallery.github.io/stable/index.html) that executes those `.py` files and generates `.rst` files with standard outputs from them and corresponding Jupyter Notebook (`.ipynb`) files.
These generated `.rst` and `.ipynb` files are written to the docs/source/tutorial directory.
The output directory (docs/source/tutorial) and source (tutorial) directory are configured in [`sphinx_gallery_conf ` of docs/source/conf.py](https://github.com/optuna/optuna/blob/2e14273cab87f13edeb9d804a43bd63c44703cb5/docs/source/conf.py#L189-L199). These generated `.rst` files are handled by Sphinx like the other `.rst` files. The generated `.ipynb` files are hosted on Optuna’s documentation page and downloadable (check [Optuna tutorial](https://optuna.readthedocs.io/en/stable/tutorial/index.html)).
The output directory (docs/source/tutorial) and source (tutorial) directory are configured in [`sphinx_gallery_conf` of docs/source/conf.py](https://github.com/optuna/optuna/blob/2e14273cab87f13edeb9d804a43bd63c44703cb5/docs/source/conf.py#L189-L199). These generated `.rst` files are handled by Sphinx like the other `.rst` files. The generated `.ipynb` files are hosted on Optuna’s documentation page and downloadable (check [Optuna tutorial](https://optuna.readthedocs.io/en/stable/tutorial/index.html)).

The order of contents on [tutorial top page](https://optuna.readthedocs.io/en/stable/tutorial/index.html) is determined by two keys: one is the subdirectory name of tutorial and the other is the filename (note that there are some alternatives as documented in [Sphinx Gallery - sorting](https://sphinx-gallery.github.io/stable/gen_modules/sphinx_gallery.sorting.html?highlight=filenamesortkey), but we chose this key in https://github.com/optuna/optuna/blob/2e14273cab87f13edeb9d804a43bd63c44703cb5/docs/source/conf.py#L196).
Optuna’s tutorial directory has two directories: (1) [10_key_features](https://github.com/optuna/optuna/tree/master/tutorial/10_key_features), which is meant to be aligned with and explain the key features listed on [README.md](https://github.com/optuna/optuna#key-features) and (2) [20_recipes](https://github.com/optuna/optuna/tree/master/tutorial/20_recipes), whose contents showcase how to use Optuna features conveniently.
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/storages.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ The :mod:`~optuna.storages` module defines a :class:`~optuna.storages.BaseStorag
optuna.storages.JournalFileStorage
optuna.storages.JournalFileSymlinkLock
optuna.storages.JournalFileOpenLock
optuna.storages.JournalRedisStorage
4 changes: 2 additions & 2 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def _split_observation_pairs(
# 3. Feasible trials are sorted by loss_vals.
if violations is not None:
violation_1d = np.array(violations, dtype=float)
idx = violation_1d.argsort()
idx = violation_1d.argsort(kind="stable")
if n_below >= len(idx) or violation_1d[idx[n_below]] > 0:
# Below is filled by all feasible trials and trials with smaller violation values.
indices_below = idx[:n_below]
Expand Down Expand Up @@ -720,7 +720,7 @@ def _split_observation_pairs(
[(s, v[0]) for s, v in loss_vals], dtype=[("step", float), ("score", float)]
)

index_loss_ascending = np.argsort(loss_values)
index_loss_ascending = np.argsort(loss_values, kind="stable")
# `np.sort` is used to keep chronological order.
indices_below = np.sort(index_loss_ascending[:n_below])
indices_above = np.sort(index_loss_ascending[n_below:])
Expand Down
2 changes: 2 additions & 0 deletions optuna/storages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from optuna.storages._journal.file import JournalFileOpenLock
from optuna.storages._journal.file import JournalFileStorage
from optuna.storages._journal.file import JournalFileSymlinkLock
from optuna.storages._journal.redis import JournalRedisStorage
from optuna.storages._journal.storage import JournalStorage
from optuna.storages._rdb.storage import RDBStorage
from optuna.storages._redis import RedisStorage
Expand All @@ -24,6 +25,7 @@
"JournalFileSymlinkLock",
"JournalFileOpenLock",
"JournalFileStorage",
"JournalRedisStorage",
"RetryFailedTrialCallback",
"_CachedStorage",
"fail_stale_trials",
Expand Down
86 changes: 86 additions & 0 deletions optuna/storages/_journal/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
import time
from typing import Any
from typing import Dict
from typing import List

from optuna._experimental import experimental_class
from optuna._imports import try_import
from optuna.storages._journal.base import BaseJournalLogStorage


with try_import() as _imports:
import redis


@experimental_class("3.1.0")
class JournalRedisStorage(BaseJournalLogStorage):
    """Redis storage class for Journal log backend.

    Args:
        url:
            URL of the redis storage, password and db are optional.
            (ie: ``redis://localhost:6379``)
        use_cluster:
            Flag whether you use the Redis cluster. If this is :obj:`False`, it is assumed that
            you use the standalone Redis server and ensured that a write operation is atomic. This
            provides the consistency of the preserved logs. If this is :obj:`True`, it is assumed
            that you use the Redis cluster and not ensured that a write operation is atomic. This
            means the preserved logs can be inconsistent due to network errors, and may
            cause errors.
        prefix:
            Prefix of the preserved key of logs. This is useful when multiple users work on one
            Redis server.
    """

    def __init__(self, url: str, use_cluster: bool = False, prefix: str = "") -> None:
        _imports.check()

        self._url = url
        self._redis = redis.Redis.from_url(url)
        self._use_cluster = use_cluster
        self._prefix = prefix

    def read_logs(self, log_number_from: int) -> List[Dict[str, Any]]:
        """Read all logs whose log number is ``log_number_from`` or later.

        Args:
            log_number_from:
                The first log number (0-based) to read.

        Returns:
            A list of log records, each deserialized from JSON.
        """
        max_log_number_bytes = self._redis.get(f"{self._prefix}:log_number")
        if max_log_number_bytes is None:
            # The counter key does not exist yet, i.e. nothing has been appended.
            return []
        max_log_number = int(max_log_number_bytes)

        logs = []
        for log_number in range(log_number_from, max_log_number + 1):
            # A writer may bump the counter before its log entry becomes visible
            # (the two operations are not atomic when ``use_cluster`` is True), so
            # poll with exponential backoff (capped at 10 seconds) until the entry appears.
            sleep_secs = 0.1
            while True:
                log_bytes = self._redis.get(self._key_log_id(log_number))
                if log_bytes is not None:
                    break
                time.sleep(sleep_secs)
                sleep_secs = min(sleep_secs * 2, 10)
            log = log_bytes.decode("utf-8")
            try:
                logs.append(json.loads(log))
            except json.JSONDecodeError:
                # Tolerate a corrupted *last* entry only: it may be a partially written
                # log from a writer that crashed mid-append. A corrupted earlier entry
                # indicates real data corruption, so re-raise with the original traceback.
                if log_number != max_log_number:
                    raise
        return logs

    def append_logs(self, logs: List[Dict[str, Any]]) -> None:
        """Append ``logs`` to the end of the journal, assigning consecutive log numbers.

        Args:
            logs:
                The list of log records to append, each serialized to JSON for storage.
        """
        # Initialize the counter to -1 (only if absent) so the first ``incr`` yields 0.
        self._redis.setnx(f"{self._prefix}:log_number", -1)
        for log in logs:
            if not self._use_cluster:
                # Bump the counter and store the entry in a single Lua script so readers
                # never observe a counter pointing at a missing entry.
                self._redis.eval(  # type: ignore
                    "local i = redis.call('incr', string.format('%s:log_number', ARGV[1])) "
                    "redis.call('set', string.format('%s:log:%d', ARGV[1], i), ARGV[2])",
                    0,
                    self._prefix,
                    json.dumps(log),
                )
            else:
                # NOTE(review): on a Redis cluster the two keys may live in different hash
                # slots, which presumably rules out the script above — confirm. The two
                # separate commands are not atomic; ``read_logs`` polls to bridge the gap.
                log_number = self._redis.incr(f"{self._prefix}:log_number", 1)
                self._redis.set(self._key_log_id(log_number), json.dumps(log))

    def _key_log_id(self, log_number: int) -> str:
        # Redis key under which the entry with the given log number is stored.
        return f"{self._prefix}:log:{log_number}"
5 changes: 5 additions & 0 deletions optuna/testing/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"redis",
"cached_redis",
"journal",
"journal_redis",
]

STORAGE_MODES_HEARTBEAT = [
Expand Down Expand Up @@ -64,6 +65,10 @@ def __enter__(
if "cached" in self.storage_specifier
else rdb_storage
)
elif self.storage_specifier == "journal_redis":
journal_redis_storage = optuna.storages.JournalRedisStorage("redis://localhost")
journal_redis_storage._redis = fakeredis.FakeStrictRedis()
return optuna.storages.JournalStorage(journal_redis_storage)
elif "redis" in self.storage_specifier:
redis_storage = optuna.storages.RedisStorage("redis://localhost", **self.extra_args)
redis_storage._redis = fakeredis.FakeStrictRedis()
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def get_extras_require() -> Dict[str, List[str]]:
],
"test": [
"codecov",
"fakeredis<=1.7.1; python_version<'3.7'",
"fakeredis ; python_version>='3.7'",
"fakeredis[lua]<=1.7.1; python_version<'3.7'",
"fakeredis[lua] ; python_version>='3.7'",
"kaleido",
"pytest",
],
Expand Down
6 changes: 4 additions & 2 deletions tests/samplers_tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,10 @@ def objective(trial: Trial, low: int, high: int) -> float:
assert all(t.state == TrialState.COMPLETE for t in study.trials)


# We add tests for constant objective functions to ensure the reproducibility of sorting.
@parametrize_sampler_with_seed
def test_reproducible(sampler_class: Callable[[int], BaseSampler]) -> None:
@pytest.mark.parametrize("objective_func", [lambda *args: sum(args), lambda *args: 0.0])
def test_reproducible(sampler_class: Callable[[int], BaseSampler], objective_func: Any) -> None:
def objective(trial: Trial) -> float:
a = trial.suggest_float("a", 1, 9)
b = trial.suggest_float("b", 1, 9, log=True)
Expand All @@ -952,7 +954,7 @@ def objective(trial: Trial) -> float:
e = trial.suggest_int("e", 1, 9, log=True)
f = trial.suggest_int("f", 1, 9, step=2)
g = cast(int, trial.suggest_categorical("g", range(1, 10)))
return a + b + c + d + e + f + g
return objective_func(a, b, c, d, e, f, g)

study = optuna.create_study(sampler=sampler_class(1))
study.optimize(objective, n_trials=20)
Expand Down
24 changes: 22 additions & 2 deletions tests/storages_tests/test_journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional
from typing import Type

import fakeredis
import pytest

import optuna
Expand All @@ -17,6 +18,8 @@
LOG_STORAGE = {
"file_with_open_lock",
"file_with_link_lock",
"redis_default",
"redis_with_use_cluster",
}


Expand All @@ -25,7 +28,7 @@ def __init__(self, storage_type: str) -> None:
self.storage_type = storage_type
self.tempfile: Optional[IO[Any]] = None

def __enter__(self) -> optuna.storages.JournalFileStorage:
def __enter__(self) -> optuna.storages.BaseJournalLogStorage:
if self.storage_type.startswith("file"):
self.tempfile = tempfile.NamedTemporaryFile()
lock: JournalFileBaseLock
Expand All @@ -36,6 +39,13 @@ def __enter__(self) -> optuna.storages.JournalFileStorage:
else:
raise Exception("Must not reach here")
return optuna.storages.JournalFileStorage(self.tempfile.name, lock)
elif self.storage_type.startswith("redis"):
use_cluster = self.storage_type == "redis_with_use_cluster"
journal_redis_storage = optuna.storages.JournalRedisStorage(
"redis://localhost", use_cluster
)
journal_redis_storage._redis = fakeredis.FakeStrictRedis()
return journal_redis_storage
else:
raise RuntimeError("Unknown log storage type: {}".format(self.storage_type))

Expand All @@ -48,7 +58,10 @@ def __exit__(


@pytest.mark.parametrize("log_storage_type", LOG_STORAGE)
def test_concurrent_append_logs(log_storage_type: str) -> None:
def test_concurrent_append_logs_for_multi_processes(log_storage_type: str) -> None:
if log_storage_type.startswith("redis"):
pytest.skip("The `fakeredis` does not support multi process environments.")

num_executors = 10
num_records = 200
record = {"key": "value"}
Expand All @@ -60,6 +73,13 @@ def test_concurrent_append_logs(log_storage_type: str) -> None:
assert len(storage.read_logs(0)) == num_records
assert all(record == r for r in storage.read_logs(0))


@pytest.mark.parametrize("log_storage_type", LOG_STORAGE)
def test_concurrent_append_logs_for_multi_threads(log_storage_type: str) -> None:
num_executors = 10
num_records = 200
record = {"key": "value"}

with JournalLogStorageSupplier(log_storage_type) as storage:
with ThreadPoolExecutor(num_executors) as pool:
pool.map(storage.append_logs, [[record] for _ in range(num_records)], timeout=20)
Expand Down
Loading

0 comments on commit 2251454

Please sign in to comment.