Skip to content

Commit

Permalink
Merge pull request #5400 from nabenabe0928/enhance/speed-up-to-intern…
Browse files Browse the repository at this point in the history
…al-repr-in-categorical-dist

Speed up `to_internal_repr` in `CategoricalDistribution`
  • Loading branch information
eukaryo committed May 1, 2024
2 parents c634449 + 25ca4d5 commit 89b0fd1
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions optuna/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,13 @@ def to_external_repr(self, param_value_in_internal_repr: float) -> CategoricalCh
return self.choices[int(param_value_in_internal_repr)]

def to_internal_repr(self, param_value_in_external_repr: CategoricalChoiceType) -> float:
for index, choice in enumerate(self.choices):
if _categorical_choice_equal(param_value_in_external_repr, choice):
return index
try:
return self.choices.index(param_value_in_external_repr)
except ValueError: # ValueError: param_value_in_external_repr is not in choices.
# ValueError also happens if external_repr is nan or includes precision error in float.
for index, choice in enumerate(self.choices):
if _categorical_choice_equal(param_value_in_external_repr, choice):
return index

raise ValueError(f"'{param_value_in_external_repr}' not in {self.choices}.")

Expand Down

0 comments on commit 89b0fd1

Please sign in to comment.