Skip to content

Commit

Permalink
Add more underscores to function names
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Nov 16, 2020
1 parent 568dfb7 commit 81b838e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 28 deletions.
59 changes: 31 additions & 28 deletions optuna/integration/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@
# Serialization utilities


def serialize_datetime(dt: datetime.datetime) -> dict:
def _serialize_datetime(dt: datetime.datetime) -> dict:
return {"__datetime__": True, "as_str": dt.strftime("%Y%m%dT%H:%M:%S.%f")}


def deserialize_datetime(data: dict) -> datetime.datetime:
def _deserialize_datetime(data: dict) -> datetime.datetime:
return datetime.datetime.strptime(data["as_str"], "%Y%m%dT%H:%M:%S.%f")


def serialize_frozentrial(trial: FrozenTrial) -> dict:
def _serialize_frozentrial(trial: FrozenTrial) -> dict:
data = trial.__dict__.copy()
data["state"] = data["state"].name
for attr in [
Expand All @@ -53,50 +53,50 @@ def serialize_frozentrial(trial: FrozenTrial) -> dict:
data[attr] = data.pop(f"_{attr}")
data["distributions"] = {k: distribution_to_json(v) for k, v in data["distributions"].items()}
if data["datetime_start"] is not None:
data["datetime_start"] = serialize_datetime(data["datetime_start"])
data["datetime_start"] = _serialize_datetime(data["datetime_start"])
if data["datetime_complete"] is not None:
data["datetime_complete"] = serialize_datetime(data["datetime_complete"])
data["datetime_complete"] = _serialize_datetime(data["datetime_complete"])
return data


def deserialize_frozentrial(data: dict) -> FrozenTrial:
def _deserialize_frozentrial(data: dict) -> FrozenTrial:
data["state"] = getattr(TrialState, data["state"])
data["distributions"] = {k: json_to_distribution(v) for k, v in data["distributions"].items()}
if data["datetime_start"] is not None:
data["datetime_start"] = deserialize_datetime(data["datetime_start"])
data["datetime_start"] = _deserialize_datetime(data["datetime_start"])
if data["datetime_complete"] is not None:
data["datetime_complete"] = deserialize_datetime(data["datetime_complete"])
data["datetime_complete"] = _deserialize_datetime(data["datetime_complete"])
trail = FrozenTrial(**data)
return trail


def serialize_studysummary(summary: StudySummary) -> dict:
def _serialize_studysummary(summary: StudySummary) -> dict:
data = summary.__dict__.copy()
data["study_id"] = data.pop("_study_id")
data["best_trial"] = serialize_frozentrial(data["best_trial"])
data["datetime_start"] = serialize_datetime(data["datetime_start"])
data["best_trial"] = _serialize_frozentrial(data["best_trial"])
data["datetime_start"] = _serialize_datetime(data["datetime_start"])
data["direction"] = data["direction"]["name"]
return data


def deserialize_studysummary(data: dict) -> StudySummary:
def _deserialize_studysummary(data: dict) -> StudySummary:
data["direction"] = getattr(StudyDirection, data["direction"])
data["best_trial"] = deserialize_frozentrial(data["best_trial"])
data["datetime_start"] = deserialize_datetime(data["datetime_start"])
data["best_trial"] = _deserialize_frozentrial(data["best_trial"])
data["datetime_start"] = _deserialize_datetime(data["datetime_start"])
summary = StudySummary(**data)
return summary


def serialize_studydirection(direction: StudyDirection) -> str:
def _serialize_studydirection(direction: StudyDirection) -> str:
return direction.name


def deserialize_studydirection(data: str) -> StudyDirection:
def _deserialize_studydirection(data: str) -> StudyDirection:
return getattr(StudyDirection, data)


class _OptunaSchedulerExtension:
def __init__(self, scheduler: distributed.Scheduler):
def __init__(self, scheduler: "distributed.Scheduler"):
self.scheduler = scheduler
self.storages: Dict[str, BaseStorage] = {}

Expand Down Expand Up @@ -186,7 +186,7 @@ def set_study_direction(
) -> None:
return self.get_storage(storage_name).set_study_direction(
study_id=study_id,
direction=deserialize_studydirection(direction),
direction=_deserialize_studydirection(direction),
)

def get_study_id_from_name(
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_study_direction(
study_id: int,
) -> str:
direction = self.get_storage(storage_name).get_study_direction(study_id=study_id)
return serialize_studydirection(direction)
return _serialize_studydirection(direction)

def get_study_user_attrs(
self,
Expand All @@ -242,7 +242,7 @@ def get_all_study_summaries(
self, comm: "distributed.comm.tcp.TCP", storage_name: str
) -> List[dict]:
summaries = self.get_storage(storage_name).get_all_study_summaries()
return [serialize_studysummary(s) for s in summaries]
return [_serialize_studysummary(s) for s in summaries]

def create_new_trial(
self,
Expand Down Expand Up @@ -365,7 +365,7 @@ def get_trial(
trial_id: int,
) -> dict:
trial = self.get_storage(storage_name).get_trial(trial_id=trial_id)
return serialize_frozentrial(trial)
return _serialize_frozentrial(trial)

def get_all_trials(
self,
Expand All @@ -378,7 +378,7 @@ def get_all_trials(
study_id=study_id,
deepcopy=deepcopy,
)
return [serialize_frozentrial(t) for t in trials]
return [_serialize_frozentrial(t) for t in trials]

def get_n_trials(
self,
Expand All @@ -402,7 +402,7 @@ def read_trials_from_remote_storage(


def _register_with_scheduler(
dask_scheduler: distributed.Scheduler, storage: str, name: str
dask_scheduler: "distributed.Scheduler", storage: str, name: str
) -> None:
if "optuna" not in dask_scheduler.extensions:
ext = _OptunaSchedulerExtension(dask_scheduler)
Expand All @@ -418,7 +418,9 @@ def _use_basestorage_doc(func: Callable) -> Callable:
if method is not None:
# Ensure BaseStorage and DaskStorage have the same signature
assert inspect.signature(func) == inspect.signature(method)
func.__doc__ = method.__doc__
# Overwrite docstring if one does not exist already
if func.__doc__ is not None:
func.__doc__ = method.__doc__
return func


Expand Down Expand Up @@ -567,7 +569,7 @@ def get_study_direction(self, study_id: int) -> StudyDirection:
storage_name=self.name,
study_id=study_id,
)
return deserialize_studydirection(direction)
return _deserialize_studydirection(direction)

@_use_basestorage_doc
def get_study_user_attrs(self, study_id: int) -> Dict[str, Any]:
Expand All @@ -589,7 +591,7 @@ async def _get_all_study_summaries(self) -> List[StudySummary]:
serialized_summaries = await self.client.scheduler.optuna_get_all_study_summaries(
storage_name=self.name
)
return [deserialize_studysummary(s) for s in serialized_summaries]
return [_deserialize_studysummary(s) for s in serialized_summaries]

@_use_basestorage_doc
def get_all_study_summaries(self) -> List[StudySummary]:
Expand Down Expand Up @@ -691,11 +693,12 @@ def set_trial_system_attr(self, trial_id: int, key: str, value: Any) -> None:
)

# Basic trial access

async def _get_trial(self, trial_id: int) -> FrozenTrial:
serialized_trial = await self.client.scheduler.optuna_get_trial(
trial_id=trial_id, storage_name=self.name
)
return deserialize_frozentrial(serialized_trial)
return _deserialize_frozentrial(serialized_trial)

@_use_basestorage_doc
def get_trial(self, trial_id: int) -> FrozenTrial:
Expand All @@ -707,7 +710,7 @@ async def _get_all_trials(self, study_id: int, deepcopy: bool = True) -> List[Fr
study_id=study_id,
deepcopy=deepcopy,
)
return [deserialize_frozentrial(t) for t in serialized_trials]
return [_deserialize_frozentrial(t) for t in serialized_trials]

@_use_basestorage_doc
def get_all_trials(self, study_id: int, deepcopy: bool = True) -> List[FrozenTrial]:
Expand Down
6 changes: 6 additions & 0 deletions tests/integration_tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def objective(trial: Trial) -> float:
return (x - 2) ** 2


@gen_cluster(client=True)
async def test_experimental(c: Client, s: Scheduler, a: Worker, b: Worker) -> None:
with pytest.warns(optuna._experimental.ExperimentalWarning):
DaskStorage()


@gen_cluster(client=True)
async def test_daskstorage_registers_extension(
c: Client, s: Scheduler, a: Worker, b: Worker
Expand Down

0 comments on commit 81b838e

Please sign in to comment.