Skip to content

Commit

Permalink
Merge pull request #5404 from nabenabe0928/enhance/allow-users-to-mod…
Browse files Browse the repository at this point in the history
…ify-categorical-distance

Allow users to modify categorical distance more easily
  • Loading branch information
eukaryo committed Apr 25, 2024
2 parents 34c24d0 + 514f2bb commit 6c082c7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def __init__(
self._search_space = IntersectionSearchSpace(include_pruned=True)
self._constant_liar = constant_liar
self._constraints_func = constraints_func
# NOTE(nabenabe0928): Users can overwrite _ParzenEstimator to customize the TPE behavior.
self._parzen_estimator_cls = _ParzenEstimator

if multivariate:
warnings.warn(
Expand Down Expand Up @@ -514,11 +516,16 @@ def _build_parzen_estimator(
weights_below = _calculate_weights_below_for_multi_objective(
study, trials, self._constraints_func
)[param_mask_below]
mpe = _ParzenEstimator(
mpe = self._parzen_estimator_cls(
observations, search_space, self._parzen_estimator_parameters, weights_below
)
else:
mpe = _ParzenEstimator(observations, search_space, self._parzen_estimator_parameters)
mpe = self._parzen_estimator_cls(
observations, search_space, self._parzen_estimator_parameters
)

if not isinstance(mpe, _ParzenEstimator):
raise RuntimeError("_parzen_estimator_cls must override _ParzenEstimator.")

return mpe

Expand Down

0 comments on commit 6c082c7

Please sign in to comment.