Use proper meta estimator semantics
TomAugspurger committed Jul 2, 2018
1 parent c615c63 commit 55fccf2
Showing 2 changed files with 43 additions and 54 deletions.
57 changes: 21 additions & 36 deletions dask_ml/wrappers.py
@@ -46,6 +46,12 @@ class ParallelPostFit(sklearn.base.BaseEstimator):
         a single NumPy array, which may exhaust the memory of your worker.
         You probably want to always specify `scoring`.
 
+    name : string, default 'estimator'
+        The name to use for the underlying estimator. This is useful in
+        settings using ``get_params`` and ``set_params``, for example when
+        performing grid search. To target the parameters of the underlying
+        estimator, use ``<name>__<parameter>``.
+
     Notes
     -----
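
The ``<name>__<parameter>`` convention is scikit-learn's standard nested-parameter syntax, which the stock ``BaseEstimator.get_params`` and ``set_params`` provide once the wrapper stops overriding them. A minimal sketch of the behavior with the default name ``'estimator'`` (illustrative; not part of this diff):

    from sklearn.linear_model import SGDClassifier
    from dask_ml.wrappers import ParallelPostFit

    clf = ParallelPostFit(estimator=SGDClassifier(tol=1e-3))

    # Nested parameters are reported under the 'estimator__' prefix.
    assert 'estimator__alpha' in clf.get_params()

    # The same prefix routes updates through to the wrapped estimator.
    clf.set_params(estimator__alpha=0.01)
    assert clf.estimator.alpha == 0.01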
@@ -322,10 +328,12 @@ class Incremental(ParallelPostFit):
         a single NumPy array, which may exhaust the memory of your worker.
         You probably want to always specify `scoring`.
-    **kwargs
-        Set the hyperparameters of `estimator`. This is used in, for example,
-        ``GridSearchCV``, to set the parameters of the underlying estimator.
-        Most of the time you will not need to use this.
+    name : string, default 'estimator'
+        The name to use for the underlying estimator. This is useful in
+        settings using ``get_params`` and ``set_params``, for example when
+        performing grid search. To target the parameters of the underlying
+        estimator, use ``<name>__<parameter>``.
 
     Attributes
     ----------
@@ -343,17 +351,17 @@ class Incremental(ParallelPostFit):
     >>> import sklearn.linear_model
     >>> X, y = make_classification(chunks=25)
     >>> est = sklearn.linear_model.SGDClassifier()
-    >>> clf = Incremental(est)
+    >>> clf = Incremental(est, scoring='accuracy')
     >>> clf.fit(X, y, classes=[0, 1])
-    """
-    _estimator_clash_message = (
-        "The 'estimator' parameter is used by both 'Incremental' and the"
-        "underlying estimator, which will produce incorrect results."
-    )
-
-    def __init__(self, estimator, scoring=None, **kwargs):
-        estimator.set_params(**kwargs)
-        super(Incremental, self).__init__(estimator=estimator, scoring=scoring)
+
+    When used inside a grid search, prefix the underlying estimator's
+    parameter names with `name`.
+
+    >>> from sklearn.model_selection import GridSearchCV
+    >>> param_grid = {"estimator__alpha": [0.1, 1.0, 10.0]}
+    >>> gs = GridSearchCV(clf, param_grid)
+    >>> gs.fit(X, y, classes=[0, 1])
+    """
 
     @property
     def _postfit_estimator(self):
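
The ``fit`` path shown just below clones the wrapped estimator with ``sklearn.base.clone``, which rebuilds an estimator from its ``get_params`` output; that round-trip is exactly what the standard meta-estimator semantics guarantee. A minimal sketch (illustrative; not part of this diff):

    import sklearn.base
    from sklearn.linear_model import SGDClassifier
    from dask_ml.wrappers import Incremental

    clf = Incremental(SGDClassifier(alpha=0.1))

    # clone() reconstructs the wrapper from get_params(), recursively
    # cloning the wrapped estimator, so hyperparameters survive intact.
    fresh = sklearn.base.clone(clf)
    assert fresh is not clf
    assert fresh.estimator is not clf.estimator
    assert fresh.estimator.alpha == 0.1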
@@ -397,29 +405,6 @@ def partial_fit(self, X, y=None, **fit_kwargs):
         estimator = sklearn.base.clone(self.estimator)
         return self._fit_for_estimator(estimator, X, y, **fit_kwargs)
 
-    def __repr__(self):
-        # Have to override, else all the parameters of estimator
-        # are duplicated
-        estimator = repr(self.estimator)
-        class_name = self.__class__.__name__
-        return '{}({})'.format(class_name, estimator)
-
-    def get_params(self, deep=True):
-        out = self.estimator.get_params(deep=deep)
-        if 'estimator' in out:
-            raise ValueError(self._estimator_clash_message)
-        out['estimator'] = self.estimator
-        out['scoring'] = self.scoring
-        return out
-
-    def set_params(self, **kwargs):
-        if 'estimator' in kwargs:
-            raise ValueError(self._estimator_clash_message)
-        if 'scoring' in kwargs:
-            self.scoring = kwargs['scoring']
-        self.estimator.set_params(**kwargs)
-        return self
 
 
 def _first_block(dask_object):
     """Extract the first block / partition from a dask object
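
For context, the same "first block / first partition" idea expressed directly against the dask collections API (assuming a reasonably recent dask; this is not necessarily how ``_first_block`` is implemented):

    import dask.array as da
    import dask.dataframe as dd
    import pandas as pd

    x = da.ones((100, 4), chunks=(25, 4))
    first_block = x.blocks[0]          # first chunk of a dask array

    df = dd.from_pandas(pd.DataFrame({'a': range(100)}), npartitions=4)
    first_part = df.get_partition(0)   # first partition of a dask DataFrame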
40 changes: 22 additions & 18 deletions tests/test_incremental.py
@@ -13,6 +13,24 @@
 from dask_ml.metrics.scorer import check_scoring
 
 
+def test_get_params():
+    clf = Incremental(SGDClassifier())
+    result = clf.get_params()
+
+    assert 'estimator__alpha' in result
+    assert result['scoring'] is None
+
+
+def test_set_params():
+    clf = Incremental(SGDClassifier())
+    clf.set_params(**{'scoring': 'accuracy',
+                      'estimator__alpha': 0.1})
+    result = clf.get_params()
+
+    assert result['estimator__alpha'] == 0.1
+    assert result['scoring'] == 'accuracy'
+
+
 def test_incremental_basic(scheduler, xy_classification):
     X, y = xy_classification
 
@@ -54,28 +72,14 @@ def test_incremental_basic(scheduler, xy_classification):

 def test_in_gridsearch(scheduler, xy_classification):
     X, y = xy_classification
+    clf = Incremental(SGDClassifier(random_state=0, tol=1e-3))
+    param_grid = {'estimator__alpha': [0.1, 10]}
+    gs = sklearn.model_selection.GridSearchCV(clf, param_grid, iid=False)
 
     with scheduler() as (s, [a, b]):
-        clf = Incremental(SGDClassifier(random_state=0, tol=1e-3))
-        param_grid = {'alpha': [0.1, 10]}
-        gs = sklearn.model_selection.GridSearchCV(clf, param_grid, iid=False)
         gs.fit(X, y, classes=[0, 1])
 
 
-def test_estimator_param_raises():
-
-    class Dummy(sklearn.base.BaseEstimator):
-        def __init__(self, estimator=42):
-            self.estimator = estimator
-
-        def fit(self, X):
-            return self
-
-    clf = Incremental(Dummy(estimator=1))
-
-    with pytest.raises(ValueError, match='used by both'):
-        clf.get_params()
-
-
 def test_scoring(scheduler, xy_classification,
                  scoring=dask_ml.metrics.accuracy_score):
     X, y = xy_classification
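
The deleted ``test_estimator_param_raises`` guarded against a name clash that the default semantics dissolve: a wrapped estimator that itself takes an ``estimator`` parameter is now simply namespaced rather than being an error. A minimal sketch of the new behavior (illustrative; not part of this diff):

    from sklearn.base import BaseEstimator
    from dask_ml.wrappers import Incremental

    class Dummy(BaseEstimator):
        def __init__(self, estimator=42):
            self.estimator = estimator

        def fit(self, X, y=None):
            return self

    clf = Incremental(Dummy(estimator=1))

    # No clash: the inner parameter lives under the 'estimator__' prefix.
    assert clf.get_params()['estimator__estimator'] == 1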
