forked from scikit-learn/scikit-learn
FEA Successive halving for faster parameter search (scikit-learn#13900)
* More flexible grid search interface
* added info dict parameter
* Put back removed test
* renamed info into more_results
* Passed groups as well since we need to use get_n_splits(X, y, groups)
* port
* pep8
* dabl -> sklearn
* add _required_parameters
* skipping check in rst file if pandas not installed
* Update sklearn/model_selection/_search_successive_halving.py
  Co-Authored-By: Joel Nothman <joel.nothman@gmail.com>
* renamed into GridHalvingSearchCV and RandomHalvingSearchCV
* Addressed thomas' comments
* repr
* removed passing group as a parameter to evaluate_candidates
* Joels comments
* pep8
* reorganized user guide
* renaming
* update user guide
* remove groups support + pass fit_params
* parameter renaming
* pep8
* r_i -> resource_iter
* fixed r_i issues
* examples + removed use of word budget
* Added input checking tests
* added cv_results_ user guide
* minor title change
* fixed doc layout
* Addressed some comments
* properly pass down fit_params
* change default value of force_exhaust_resources and update doc
* should fix doc
* Used check_fit_params
* Update section about min_resources and number of candidates
* Clarified ratio section
* Use ~ to refer to classes
* fixed doc checks
* Apply suggestions from code review
  Co-authored-by: Joel Nothman <joel.nothman@gmail.com>
* Addressed easy comments from Joel
* missed some
* updated docstring of run_search
* Used f strings instead of format
* remove candidate duplication checks
* fix example
* Addressed easy comments
* rotate ticks labels
* Added discussion in the intro as suggested by Joel
* Split examples into sections
* minor changes
* remove force_exhaust_budget and introduce min_resources=exhaust
* some minor validation
* Added a n_resources_ attribute
* update examples
* Addressed comments
* passing CV instead of X,y
* minor revert for handling fit_params
* updated docs
* fix len
* whatsnew
* Add test for sampling when all_list
* minor change to top-k
* Force CV splits to be consistent across calls
* reorder parameters
* reduced diff
* added tests for top_k
* put back doc for groups
* not sure what went wrong
* put import at its place
* some comment
* Addressed comments
* Added tests for cv_results_ and base estimator inputs
* pep8
* avoid monkeypatching
* rename df
* use Joel's suggestions for testing masks
* Made it experimental
* Should fix docs
* whats new entry
* Apply suggestions from code review
  Co-authored-by: Andreas Mueller <t3kcit@gmail.com>
* Addressed comments to docs
* Addressed comments in examples
* minor doc update
* minor renaming in UG
* forgot some
* some sad note about splitter statefulness :'(
* Addressed comments
* ratio -> factor

Co-authored-by: Joel Nothman <joel.nothman@gmail.com>
Co-authored-by: Andreas Mueller <t3kcit@gmail.com>
Showing 16 changed files with 2,276 additions and 29 deletions.
122 changes: 122 additions & 0 deletions
examples/model_selection/plot_successive_halving_heatmap.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
Comparison between grid search and successive halving | ||
===================================================== | ||
This example compares the parameter search performed by | ||
:class:`~sklearn.model_selection.HalvingGridSearchCV` and | ||
:class:`~sklearn.model_selection.GridSearchCV`. | ||
""" | ||
from time import time | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from sklearn.svm import SVC | ||
from sklearn import datasets | ||
from sklearn.model_selection import GridSearchCV | ||
from sklearn.experimental import enable_successive_halving # noqa | ||
from sklearn.model_selection import HalvingGridSearchCV | ||
|
||
|
||
print(__doc__) | ||
|
||
# %% | ||
# We first define the parameter space for an :class:`~sklearn.svm.SVC` | ||
# estimator, and compute the time required to train a | ||
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a | ||
# :class:`~sklearn.model_selection.GridSearchCV` instance. | ||
|
||
rng = np.random.RandomState(0) | ||
X, y = datasets.make_classification(n_samples=1000, random_state=rng) | ||
|
||
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7] | ||
Cs = [1, 10, 100, 1e3, 1e4, 1e5] | ||
param_grid = {'gamma': gammas, 'C': Cs} | ||
|
||
clf = SVC(random_state=rng) | ||
|
||
tic = time() | ||
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2, | ||
random_state=rng) | ||
gsh.fit(X, y) | ||
gsh_time = time() - tic | ||
|
||
tic = time() | ||
gs = GridSearchCV(estimator=clf, param_grid=param_grid) | ||
gs.fit(X, y) | ||
gs_time = time() - tic | ||
|
||
# %% | ||
# We now plot heatmaps for both search estimators. | ||
|
||
|
||
def make_heatmap(ax, gs, is_sh=False, make_cbar=False): | ||
"""Helper to make a heatmap.""" | ||
results = pd.DataFrame.from_dict(gs.cv_results_) | ||
results['params_str'] = results.params.apply(str) | ||
if is_sh: | ||
# SH dataframe: get mean_test_score values for the highest iter | ||
scores_matrix = results.sort_values('iter').pivot_table( | ||
index='param_gamma', columns='param_C', | ||
values='mean_test_score', aggfunc='last' | ||
) | ||
else: | ||
scores_matrix = results.pivot(index='param_gamma', columns='param_C', | ||
values='mean_test_score') | ||
|
||
im = ax.imshow(scores_matrix) | ||
|
||
ax.set_xticks(np.arange(len(Cs))) | ||
ax.set_xticklabels(['{:.0E}'.format(x) for x in Cs]) | ||
ax.set_xlabel('C', fontsize=15) | ||
|
||
ax.set_yticks(np.arange(len(gammas))) | ||
ax.set_yticklabels(['{:.0E}'.format(x) for x in gammas]) | ||
ax.set_ylabel('gamma', fontsize=15) | ||
|
||
# Rotate the tick labels and set their alignment. | ||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", | ||
rotation_mode="anchor") | ||
|
||
if is_sh: | ||
iterations = results.pivot_table(index='param_gamma', | ||
columns='param_C', values='iter', | ||
aggfunc='max').values | ||
for i in range(len(gammas)): | ||
for j in range(len(Cs)): | ||
ax.text(j, i, iterations[i, j], | ||
ha="center", va="center", color="w", fontsize=20) | ||
|
||
if make_cbar: | ||
fig.subplots_adjust(right=0.8) | ||
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) | ||
fig.colorbar(im, cax=cbar_ax) | ||
cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom", | ||
fontsize=15) | ||
|
||
|
||
fig, axes = plt.subplots(ncols=2, sharey=True) | ||
ax1, ax2 = axes | ||
|
||
make_heatmap(ax1, gsh, is_sh=True) | ||
make_heatmap(ax2, gs, make_cbar=True) | ||
|
||
ax1.set_title('Successive Halving\ntime = {:.3f}s'.format(gsh_time), | ||
fontsize=15) | ||
ax2.set_title('GridSearch\ntime = {:.3f}s'.format(gs_time), fontsize=15) | ||
|
||
plt.show() | ||
|
||
# %% | ||
# The heatmaps show the mean test score of the parameter combinations for an | ||
# :class:`~sklearn.svm.SVC` instance. The | ||
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the | ||
# iteration at which the combinations were last used. The combinations marked | ||
# as ``0`` were only evaluated at the first iteration, while the ones with | ||
# ``5`` are the parameter combinations that are considered the best ones. | ||
# | ||
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV` | ||
# class is able to find parameter combinations that are just as accurate as | ||
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time. |
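The iteration numbers drawn on the successive halving heatmap follow from the halving schedule. As a rough sketch only (this is not scikit-learn's implementation), the function below models how many candidates and how many samples each iteration would see. The helper name `halving_schedule` and the starting value `min_resources=31` are assumptions made for illustration; the 42 candidates correspond to the 7 values of `gamma` times the 6 values of `C` in the grid above, and 1000 is the number of samples.

```python
from math import ceil


def halving_schedule(n_candidates, min_resources, max_resources, factor=2):
    """Simplified model of a successive halving schedule.

    Returns a list of (iteration, n_candidates, n_resources) tuples.
    Hypothetical helper for illustration; a real implementation may
    handle edge cases differently.
    """
    # Largest number of iterations such that
    # factor ** (n_iterations - 1) <= n_candidates, computed with
    # integers to avoid floating point logarithms.
    n_iterations = 0
    while factor ** n_iterations <= n_candidates:
        n_iterations += 1

    schedule = []
    n_resources = min_resources
    for iteration in range(n_iterations):
        if n_resources > max_resources:
            break  # not enough resources left for another iteration
        schedule.append((iteration, n_candidates, n_resources))
        # Keep the best 1/factor of the candidates and give each
        # survivor factor times more resources.
        n_candidates = ceil(n_candidates / factor)
        n_resources *= factor
    return schedule


for iteration, n_cand, n_res in halving_schedule(42, 31, 1000, factor=2):
    print(f"iter={iteration}  n_candidates={n_cand}  n_resources={n_res}")
```

Under these assumptions the schedule has six iterations (0 to 5) with 42, 21, 11, 6, 3 and 2 candidates, which is consistent with the surviving combinations being labelled ``5`` in the heatmap.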
84 changes: 84 additions & 0 deletions
examples/model_selection/plot_successive_halving_iterations.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
Successive Halving Iterations | ||
============================= | ||
This example illustrates how a successive halving search ( | ||
:class:`~sklearn.model_selection.HalvingGridSearchCV` and | ||
:class:`~sklearn.model_selection.HalvingRandomSearchCV`) iteratively chooses | ||
the best parameter combination out of multiple candidates. | ||
""" | ||
import pandas as pd | ||
from sklearn import datasets | ||
import matplotlib.pyplot as plt | ||
from scipy.stats import randint | ||
import numpy as np | ||
|
||
from sklearn.experimental import enable_successive_halving # noqa | ||
from sklearn.model_selection import HalvingRandomSearchCV | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
|
||
print(__doc__) | ||
|
||
# %% | ||
# We first define the parameter space and train a | ||
# :class:`~sklearn.model_selection.HalvingRandomSearchCV` instance. | ||
|
||
rng = np.random.RandomState(0) | ||
|
||
X, y = datasets.make_classification(n_samples=700, random_state=rng) | ||
|
||
clf = RandomForestClassifier(n_estimators=20, random_state=rng) | ||
|
||
param_dist = {"max_depth": [3, None], | ||
"max_features": randint(1, 11), | ||
"min_samples_split": randint(2, 11), | ||
"bootstrap": [True, False], | ||
"criterion": ["gini", "entropy"]} | ||
|
||
rsh = HalvingRandomSearchCV( | ||
estimator=clf, | ||
param_distributions=param_dist, | ||
factor=2, | ||
random_state=rng) | ||
rsh.fit(X, y) | ||
|
||
# %% | ||
# We can now use the `cv_results_` attribute of the search estimator to inspect | ||
# and plot the evolution of the search. | ||
|
||
results = pd.DataFrame(rsh.cv_results_) | ||
results['params_str'] = results.params.apply(str) | ||
results.drop_duplicates(subset=('params_str', 'iter'), inplace=True) | ||
mean_scores = results.pivot(index='iter', columns='params_str', | ||
values='mean_test_score') | ||
ax = mean_scores.plot(legend=False, alpha=.6) | ||
|
||
labels = [ | ||
f'iter={i}\nn_samples={rsh.n_resources_[i]}\n' | ||
f'n_candidates={rsh.n_candidates_[i]}' | ||
for i in range(rsh.n_iterations_) | ||
] | ||
ax.set_xticks(range(rsh.n_iterations_)) | ||
ax.set_xticklabels(labels, rotation=45, multialignment='left') | ||
ax.set_title('Scores of candidates over iterations') | ||
ax.set_ylabel('mean test score', fontsize=15) | ||
ax.set_xlabel('iterations', fontsize=15) | ||
plt.tight_layout() | ||
plt.show() | ||
|
||
# %% | ||
# Number of candidates and amount of resource at each iteration | ||
# ------------------------------------------------------------- | ||
# | ||
# At the first iteration, a small amount of resources is used. The resource | ||
# here is the number of samples that the estimators are trained on. All | ||
# candidates are evaluated. | ||
# | ||
# At the second iteration, only the best half of the candidates is evaluated. | ||
# The number of allocated resources is doubled: candidates are evaluated on | ||
# twice as many samples. | ||
# | ||
# This process is repeated until the last iteration, where only 2 candidates | ||
# are left. The best candidate is the candidate that has the best score at the | ||
# last iteration. |
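The process described in the comments above can be sketched as a short loop. This is a toy model rather than the code added by this commit: `successive_halving`, the `evaluate` callable and the scoring formula are all hypothetical, and the "resource" is just a number passed to the scorer instead of a number of training samples.

```python
from math import ceil


def successive_halving(candidates, evaluate, min_resources, factor=2):
    """Toy successive halving loop.

    `evaluate(candidate, n_resources)` is a hypothetical scoring callable
    (higher is better). Returns (best_candidate, history), where history
    holds one (n_resources, scores) pair per iteration.
    """
    survivors = list(candidates)
    n_resources = min_resources
    history = []
    while True:
        # Score every remaining candidate with the current resources.
        scores = {cand: evaluate(cand, n_resources) for cand in survivors}
        history.append((n_resources, dict(scores)))
        if len(survivors) <= factor:
            break  # last iteration: pick the winner among the few left
        # Keep the best 1/factor of the candidates...
        n_keep = ceil(len(survivors) / factor)
        survivors = sorted(survivors, key=scores.get, reverse=True)[:n_keep]
        # ...and evaluate them on factor times more resources next time.
        n_resources *= factor
    best = max(scores, key=scores.get)
    return best, history


def toy_score(max_depth, n_resources):
    # Made-up score: best at max_depth == 5, slightly better with more data.
    return 1.0 - abs(max_depth - 5) / 10 - 1.0 / n_resources


best, history = successive_halving(range(1, 9), toy_score, min_resources=20)
for n_resources, scores in history:
    print(f"n_resources={n_resources}  n_candidates={len(scores)}")
print("best candidate:", best)
```

With 8 candidates and `factor=2` the pool shrinks to 4 and then 2 candidates while the resources double each time, and the winner is the candidate with the best score in the last iteration.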
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Enables Successive Halving search-estimators | ||
The API and results of these estimators might change without any deprecation | ||
cycle. | ||
Importing this file dynamically sets the | ||
:class:`~sklearn.model_selection.HalvingRandomSearchCV` and | ||
:class:`~sklearn.model_selection.HalvingGridSearchCV` as attributes of the | ||
`model_selection` module:: | ||
>>> # explicitly require this experimental feature | ||
>>> from sklearn.experimental import enable_successive_halving # noqa | ||
>>> # now you can import normally from model_selection | ||
>>> from sklearn.model_selection import HalvingRandomSearchCV | ||
>>> from sklearn.model_selection import HalvingGridSearchCV | ||
The ``# noqa`` comment can be removed: it just tells linters like | ||
flake8 to ignore the import, which appears as unused. | ||
""" | ||
|
||
from ..model_selection._search_successive_halving import ( | ||
HalvingRandomSearchCV, | ||
HalvingGridSearchCV | ||
) | ||
|
||
from .. import model_selection | ||
|
||
# use setattr to avoid mypy errors when monkeypatching | ||
setattr(model_selection, "HalvingRandomSearchCV", | ||
HalvingRandomSearchCV) | ||
setattr(model_selection, "HalvingGridSearchCV", | ||
HalvingGridSearchCV) | ||
|
||
model_selection.__all__ += ['HalvingRandomSearchCV', 'HalvingGridSearchCV'] |
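The opt-in pattern used by this module can be illustrated without scikit-learn. The sketch below is a stand-alone illustration under stated assumptions: `public_module`, `ExperimentalSearch` and `enable_experimental` are hypothetical names, and the membership check on `__all__` is an extra that the module above does not perform.

```python
import types

# Hypothetical stand-in for a public package module such as `model_selection`.
public_module = types.ModuleType("public_module")
public_module.__all__ = ["StableSearch"]


class ExperimentalSearch:
    """Placeholder for an estimator that is still experimental."""


def enable_experimental(module):
    """Expose ExperimentalSearch on `module`; safe to call more than once."""
    # setattr attaches the class to the module at runtime, so that
    # `from module import ExperimentalSearch` works afterwards.
    setattr(module, "ExperimentalSearch", ExperimentalSearch)
    if "ExperimentalSearch" not in module.__all__:
        module.__all__ += ["ExperimentalSearch"]


available_before = hasattr(public_module, "ExperimentalSearch")
enable_experimental(public_module)
available_after = hasattr(public_module, "ExperimentalSearch")
print(available_before, available_after)
```

Until the enabling function runs, the attribute does not exist on the module, which is what makes the feature explicitly opt-in.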
43 changes: 43 additions & 0 deletions
sklearn/experimental/tests/test_enable_successive_halving.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Tests for making sure experimental imports work as expected.""" | ||
|
||
import textwrap | ||
|
||
from sklearn.utils._testing import assert_run_python_script | ||
|
||
|
||
def test_imports_strategies(): | ||
# Make sure different import strategies work or fail as expected. | ||
|
||
# Since Python caches the imported modules, we need to run a child process | ||
# for every test case. Else, the tests would not be independent | ||
# (manually removing the imports from the cache (sys.modules) is not | ||
# recommended and can lead to many complications). | ||
|
||
good_import = """ | ||
from sklearn.experimental import enable_successive_halving | ||
from sklearn.model_selection import HalvingGridSearchCV | ||
from sklearn.model_selection import HalvingRandomSearchCV | ||
""" | ||
assert_run_python_script(textwrap.dedent(good_import)) | ||
|
||
good_import_with_model_selection_first = """ | ||
import sklearn.model_selection | ||
from sklearn.experimental import enable_successive_halving | ||
from sklearn.model_selection import HalvingGridSearchCV | ||
from sklearn.model_selection import HalvingRandomSearchCV | ||
""" | ||
assert_run_python_script( | ||
textwrap.dedent(good_import_with_model_selection_first) | ||
) | ||
|
||
bad_imports = """ | ||
import pytest | ||
with pytest.raises(ImportError): | ||
from sklearn.model_selection import HalvingGridSearchCV | ||
import sklearn.experimental | ||
with pytest.raises(ImportError): | ||
from sklearn.model_selection import HalvingGridSearchCV | ||
""" | ||
assert_run_python_script(textwrap.dedent(bad_imports)) |
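The test above relies on running each snippet in a child process so that the parent's module cache cannot influence the result. A minimal sketch of that idea using only the standard library is shown below; `run_in_fresh_interpreter` is a hypothetical helper and is not claimed to match the behaviour of scikit-learn's private `assert_run_python_script`.

```python
import subprocess
import sys
import textwrap


def run_in_fresh_interpreter(source, timeout=60):
    """Run `source` in a new Python process and return the CompletedProcess."""
    # A new interpreter starts with its own empty module cache, so imports
    # done by earlier snippets cannot leak into this one.
    return subprocess.run(
        [sys.executable, "-c", textwrap.dedent(source)],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


ok = run_in_fresh_interpreter("""
    import json
    print(json.dumps({"status": "ok"}))
""")
failing = run_in_fresh_interpreter("""
    from json import name_that_does_not_exist
""")
print(ok.returncode, failing.returncode != 0)
```

A non-zero return code signals that the snippet raised, and `stderr` carries the traceback for diagnosis.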