Skip to content

Commit

Permalink
Merge pull request #5452 from aisha-partha/master
Browse files Browse the repository at this point in the history
Adapt to `__future__.annotations` in `optuna/storages/_rdb/models.py`
  • Loading branch information
not522 committed May 23, 2024
2 parents 1a2dfd4 + 9b073d6 commit f2aa1b5
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions optuna/storages/_rdb/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import enum
import math
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

from sqlalchemy import asc
from sqlalchemy import case
Expand Down Expand Up @@ -77,7 +76,7 @@ def find_or_raise_by_id(
return study

@classmethod
def find_by_name(cls, study_name: str, session: orm.Session) -> Optional["StudyModel"]:
def find_by_name(cls, study_name: str, session: orm.Session) -> "StudyModel" | None:
study = session.query(cls).filter(cls.study_name == study_name).one_or_none()

return study
Expand All @@ -104,7 +103,7 @@ class StudyDirectionModel(BaseModel):
)

@classmethod
def where_study_id(cls, study_id: int, session: orm.Session) -> List["StudyDirectionModel"]:
def where_study_id(cls, study_id: int, session: orm.Session) -> list["StudyDirectionModel"]:
return session.query(cls).filter(cls.study_id == study_id).all()


Expand All @@ -123,7 +122,7 @@ class StudyUserAttributeModel(BaseModel):
@classmethod
def find_by_study_and_key(
cls, study: StudyModel, key: str, session: orm.Session
) -> Optional["StudyUserAttributeModel"]:
) -> "StudyUserAttributeModel" | None:
attribute = (
session.query(cls)
.filter(cls.study_id == study.study_id)
Expand All @@ -136,7 +135,7 @@ def find_by_study_and_key(
@classmethod
def where_study_id(
cls, study_id: int, session: orm.Session
) -> List["StudyUserAttributeModel"]:
) -> list["StudyUserAttributeModel"]:
return session.query(cls).filter(cls.study_id == study_id).all()


Expand All @@ -155,7 +154,7 @@ class StudySystemAttributeModel(BaseModel):
@classmethod
def find_by_study_and_key(
cls, study: StudyModel, key: str, session: orm.Session
) -> Optional["StudySystemAttributeModel"]:
) -> "StudySystemAttributeModel" | None:
attribute = (
session.query(cls)
.filter(cls.study_id == study.study_id)
Expand All @@ -168,7 +167,7 @@ def find_by_study_and_key(
@classmethod
def where_study_id(
cls, study_id: int, session: orm.Session
) -> List["StudySystemAttributeModel"]:
) -> list["StudySystemAttributeModel"]:
return session.query(cls).filter(cls.study_id == study_id).all()


Expand Down Expand Up @@ -259,10 +258,7 @@ def find_or_raise_by_id(

@classmethod
def count(
cls,
session: orm.Session,
study: Optional[StudyModel] = None,
state: Optional[TrialState] = None,
cls, session: orm.Session, study: StudyModel | None = None, state: TrialState | None = None
) -> int:
trial_count = session.query(func.count(cls.trial_id))
if study is not None:
Expand Down Expand Up @@ -294,7 +290,7 @@ class TrialUserAttributeModel(BaseModel):
@classmethod
def find_by_trial_and_key(
cls, trial: TrialModel, key: str, session: orm.Session
) -> Optional["TrialUserAttributeModel"]:
) -> "TrialUserAttributeModel" | None:
attribute = (
session.query(cls)
.filter(cls.trial_id == trial.trial_id)
Expand All @@ -307,7 +303,7 @@ def find_by_trial_and_key(
@classmethod
def where_trial_id(
cls, trial_id: int, session: orm.Session
) -> List["TrialUserAttributeModel"]:
) -> list["TrialUserAttributeModel"]:
return session.query(cls).filter(cls.trial_id == trial_id).all()


Expand All @@ -326,7 +322,7 @@ class TrialSystemAttributeModel(BaseModel):
@classmethod
def find_by_trial_and_key(
cls, trial: TrialModel, key: str, session: orm.Session
) -> Optional["TrialSystemAttributeModel"]:
) -> "TrialSystemAttributeModel" | None:
attribute = (
session.query(cls)
.filter(cls.trial_id == trial.trial_id)
Expand All @@ -339,7 +335,7 @@ def find_by_trial_and_key(
@classmethod
def where_trial_id(
cls, trial_id: int, session: orm.Session
) -> List["TrialSystemAttributeModel"]:
) -> list["TrialSystemAttributeModel"]:
return session.query(cls).filter(cls.trial_id == trial_id).all()


Expand Down Expand Up @@ -381,7 +377,7 @@ def _check_compatibility_with_previous_trial_param_distributions(
@classmethod
def find_by_trial_and_param_name(
cls, trial: TrialModel, param_name: str, session: orm.Session
) -> Optional["TrialParamModel"]:
) -> "TrialParamModel" | None:
param_distribution = (
session.query(cls)
.filter(cls.trial_id == trial.trial_id)
Expand All @@ -403,7 +399,7 @@ def find_or_raise_by_trial_and_param_name(
return param_distribution

@classmethod
def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialParamModel"]:
def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialParamModel"]:
trial_params = session.query(cls).filter(cls.trial_id == trial_id).all()

return trial_params
Expand All @@ -428,10 +424,7 @@ class TrialValueType(enum.Enum):
)

@classmethod
def value_to_stored_repr(
cls,
value: float,
) -> Tuple[Optional[float], TrialValueType]:
def value_to_stored_repr(cls, value: float) -> tuple[float | None, TrialValueType]:
if value == float("inf"):
return (None, cls.TrialValueType.INF_POS)
elif value == float("-inf"):
Expand All @@ -440,7 +433,7 @@ def value_to_stored_repr(
return (value, cls.TrialValueType.FINITE)

@classmethod
def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType) -> float:
def stored_repr_to_value(cls, value: float | None, float_type: TrialValueType) -> float:
if float_type == cls.TrialValueType.INF_POS:
assert value is None
return float("inf")
Expand All @@ -455,7 +448,7 @@ def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType
@classmethod
def find_by_trial_and_objective(
cls, trial: TrialModel, objective: int, session: orm.Session
) -> Optional["TrialValueModel"]:
) -> "TrialValueModel" | None:
trial_value = (
session.query(cls)
.filter(cls.trial_id == trial.trial_id)
Expand All @@ -466,7 +459,7 @@ def find_by_trial_and_objective(
return trial_value

@classmethod
def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialValueModel"]:
def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialValueModel"]:
trial_values = (
session.query(cls).filter(cls.trial_id == trial_id).order_by(asc(cls.objective)).all()
)
Expand Down Expand Up @@ -495,9 +488,8 @@ class TrialIntermediateValueType(enum.Enum):

@classmethod
def intermediate_value_to_stored_repr(
cls,
value: float,
) -> Tuple[Optional[float], TrialIntermediateValueType]:
cls, value: float
) -> tuple[float | None, TrialIntermediateValueType]:
if math.isnan(value):
return (None, cls.TrialIntermediateValueType.NAN)
elif value == float("inf"):
Expand All @@ -509,7 +501,7 @@ def intermediate_value_to_stored_repr(

@classmethod
def stored_repr_to_intermediate_value(
cls, value: Optional[float], float_type: TrialIntermediateValueType
cls, value: float | None, float_type: TrialIntermediateValueType
) -> float:
if float_type == cls.TrialIntermediateValueType.NAN:
assert value is None
Expand All @@ -528,7 +520,7 @@ def stored_repr_to_intermediate_value(
@classmethod
def find_by_trial_and_step(
cls, trial: TrialModel, step: int, session: orm.Session
) -> Optional["TrialIntermediateValueModel"]:
) -> "TrialIntermediateValueModel" | None:
trial_intermediate_value = (
session.query(cls)
.filter(cls.trial_id == trial.trial_id)
Expand All @@ -541,7 +533,7 @@ def find_by_trial_and_step(
@classmethod
def where_trial_id(
cls, trial_id: int, session: orm.Session
) -> List["TrialIntermediateValueModel"]:
) -> list["TrialIntermediateValueModel"]:
trial_intermediate_values = session.query(cls).filter(cls.trial_id == trial_id).all()

return trial_intermediate_values
Expand All @@ -559,9 +551,7 @@ class TrialHeartbeatModel(BaseModel):
)

@classmethod
def where_trial_id(
cls, trial_id: int, session: orm.Session
) -> Optional["TrialHeartbeatModel"]:
def where_trial_id(cls, trial_id: int, session: orm.Session) -> "TrialHeartbeatModel" | None:
return session.query(cls).filter(cls.trial_id == trial_id).one_or_none()


Expand All @@ -574,6 +564,6 @@ class VersionInfoModel(BaseModel):
library_version = _Column(String(MAX_VERSION_LENGTH))

@classmethod
def find(cls, session: orm.Session) -> Optional["VersionInfoModel"]:
def find(cls, session: orm.Session) -> "VersionInfoModel" | None:
version_info = session.query(cls).one_or_none()
return version_info

0 comments on commit f2aa1b5

Please sign in to comment.