Skip to content

Commit

Permalink
Add confidence interval and standard error of the score
Browse files Browse the repository at this point in the history
Closes #84
  • Loading branch information
kiudee committed Aug 24, 2020
1 parent 5d156ad commit 4215024
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ History
-------------------------
* Improve default parameters to be slightly more robust for most use cases and
be more in line with what a user might expect.
* Add confidence interval and standard error of the score of the estimated
global optimum to the logging output
* Fix debug output being spammed by other libraries.

0.6.0-beta.0 (2020-08-24)
Expand Down
18 changes: 16 additions & 2 deletions tune/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from atomicwrites import AtomicWriter
from bask.optimizer import Optimizer
from scipy.special import erfinv
from skopt.utils import create_result

from tune.db_workers import TuningClient, TuningServer
Expand Down Expand Up @@ -353,9 +354,21 @@ def local( # noqa: C901
try:
best_point, best_value = expected_ucb(result_object, alpha=0.0)
best_point_dict = dict(zip(param_ranges.keys(), best_point))
_, best_std = opt.gp.predict(
opt.space.transform([best_point]), return_std=True
)
root_logger.info(f"Current optimum:\n{best_point_dict}")
root_logger.info(f"Estimated value: {best_value}")
root_logger.info(
f"Estimated value: {np.around(best_value, 4)} +- "
f"{np.around(best_std, 4).item()}"
)
confidence_val = settings.get("confidence", confidence)
confidence_mult = erfinv(confidence_val) * np.sqrt(2)
root_logger.info(
f"{confidence_val * 100}% confidence interval of the value: "
f"({np.around(best_value - confidence_mult * best_std, 4).item()}, "
f"{np.around(best_value + confidence_mult * best_std, 4).item()})"
)
confidence_out = confidence_intervals(
optimizer=opt,
param_names=list(param_ranges.keys()),
Expand All @@ -364,7 +377,8 @@ def local( # noqa: C901
multimodal=False,
)
root_logger.info(
f"{confidence_val*100}% confidence intervals:\n{confidence_out}"
f"{confidence_val * 100}% confidence intervals of the parameters:"
f"\n{confidence_out}"
)
except ValueError:
root_logger.info(
Expand Down

1 comment on commit 4215024

@Claes1981
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Please sign in to comment.