Skip to content

Commit

Permalink
Change repr of estimators, reorganize
Browse files Browse the repository at this point in the history
- Change reprs of wrappers to be more informative
- Rename `Estimator` to `Wrapped` to better match `Chained`
- Rename estimator.py -> wrapped.py, and fixed accordingly.
  • Loading branch information
jcrist committed Jun 28, 2016
1 parent 827dc30 commit 09b7878
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 44 deletions.
2 changes: 1 addition & 1 deletion dklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .core import from_sklearn
from .pipeline import Pipeline
from .estimator import Estimator
from .wrapped import Wrapped
from .chained import Chained

__version__ = '0.1.0'
6 changes: 2 additions & 4 deletions dklearn/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dask.delayed import delayed
from sklearn.base import clone, is_classifier

from .estimator import Estimator
from .wrapped import Wrapped
from .utils import unpack_arguments, unpack_as_lists_of_keys, check_X_y


Expand All @@ -24,9 +24,7 @@ def _partial_fit(est, X, y, classes, kwargs):
return est.partial_fit(X, y, classes=classes, **kwargs)


class Chained(Estimator):
_finalize = staticmethod(lambda res: res[0])

class Chained(Wrapped):
def __init__(self, est, dask=None, name=None):
super(Chained, self).__init__(est, dask=dask, name=name)
if not hasattr(est, 'partial_fit'):
Expand Down
2 changes: 1 addition & 1 deletion dklearn/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from . import matrix as dm
from .core import from_sklearn
from .estimator import _transform, WrapperMixin
from .wrapped import _transform, WrapperMixin
from .utils import unpack_arguments


Expand Down
4 changes: 2 additions & 2 deletions dklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class BaseSearchCV(BaseEstimator):
def __init__(self, estimator, scoring=None, fit_params=None, iid=True,
refit=True, cv=None, get=None):
self.scoring = scoring
self.estimator = from_sklearn(estimator)
self.estimator = estimator
self.fit_params = fit_params if fit_params is not None else {}
self.iid = iid
self.refit = refit
Expand All @@ -72,7 +72,7 @@ def _estimator_type(self):
return self.estimator._estimator_type

def _fit(self, X, y, parameter_iterable):
estimator = self.estimator
estimator = from_sklearn(self.estimator)
self.scorer_ = check_scoring(estimator, scoring=self.scoring)
cv = check_cv(self.cv, X, y, classifier=is_classifier(estimator))
n_folds = len(cv)
Expand Down
15 changes: 10 additions & 5 deletions dklearn/tests/test_chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from toolz import dissoc

import dklearn.matrix as dm
from dklearn.estimator import Estimator
from dklearn.wrapped import Wrapped
from dklearn.chained import Chained

# Not fit estimators raise NotFittedError, but old versions of scikit-learn
Expand Down Expand Up @@ -56,14 +56,13 @@ def test_tokenize_Chained():
c2 = Chained(sgd2)
assert tokenize(c1) == tokenize(c1)
assert tokenize(c1) != tokenize(c2)
assert tokenize(c1) != tokenize(Estimator(sgd1))
assert tokenize(c1) != tokenize(Wrapped(sgd1))


def test_clone():
c = Chained(sgd1)
c2 = clone(c)
assert (dissoc(c.get_params(), 'estimator') ==
dissoc(c2.get_params(), 'estimator'))
assert isinstance(c2, Chained)
assert c._name == c2._name
assert c._est is not c2._est

Expand Down Expand Up @@ -95,7 +94,7 @@ def test_set_params():
def test_setattr():
c = Chained(sgd1)
with pytest.raises(AttributeError):
c.estimator = sgd2
c.penalty = 'l2'


def test_getattr():
Expand All @@ -111,6 +110,12 @@ def test_dir():
assert 'penalty' in attrs


def test_repr():
c = Chained(sgd1)
res = repr(c)
assert res.startswith('Chained')


def fit_test(c, X, y):
fit = c.fit(X, y)
assert fit is not c
Expand Down
3 changes: 1 addition & 2 deletions dklearn/tests/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def test_dir(self):

def test_repr(self):
d = from_sklearn(self.sk)
res = repr(d)
assert res.startswith('Dask')
assert repr(d) == repr(self.sk)

def test_fit(self):
d = from_sklearn(self.sk)
Expand Down
10 changes: 5 additions & 5 deletions dklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from toolz import dissoc

from dklearn import from_sklearn
from dklearn.estimator import Estimator
from dklearn.wrapped import Wrapped
from dklearn.pipeline import Pipeline

digits = load_digits()
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_clone():
lr2 = d2.named_steps['logistic']
assert lr1 is not lr2
assert lr1.get_params() == lr2.get_params()
assert isinstance(lr2, Estimator)
assert isinstance(lr2, Wrapped)


def test__estimator_type():
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_set_params():
lr1 = d.named_steps['logistic']
lr2 = d2.named_steps['logistic']
assert lr1.get_params() == lr2.get_params()
assert isinstance(lr2, Estimator)
assert isinstance(lr2, Wrapped)

# Fast return
d2 = d.set_params()
Expand All @@ -138,8 +138,8 @@ def test_set_params():
def test_named_steps():
d = from_sklearn(pipe1)
steps = d.named_steps
assert isinstance(steps['pca'], Estimator)
assert isinstance(steps['logistic'], Estimator)
assert isinstance(steps['pca'], Wrapped)
assert isinstance(steps['logistic'], Wrapped)


def test_setattr():
Expand Down
26 changes: 13 additions & 13 deletions dklearn/tests/test_estimator.py → dklearn/tests/test_wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.linear_model import LogisticRegression

from dklearn import from_sklearn
from dklearn.estimator import Estimator
from dklearn.wrapped import Wrapped

# Not fit estimators raise NotFittedError, but old versions of scikit-learn
# include two definitions of this error, which makes it hard to catch
Expand Down Expand Up @@ -44,25 +44,25 @@ def test_from_sklearn():


def test_Estimator__init__():
d = Estimator(LogisticRegression(C=1000))
d = Wrapped(LogisticRegression(C=1000))
assert d._name == from_sklearn(clf1)._name

with pytest.raises(ValueError):
Estimator(clf1, name='foo')
Wrapped(clf1, name='foo')

with pytest.raises(TypeError):
Estimator("not an estimator")
Wrapped("not an estimator")


def test_clone():
d = Estimator(clf1)
d = Wrapped(clf1)
d2 = clone(d)
assert d.get_params() == d2.get_params()
assert d._est is not d2._est


def test__estimator_type():
d = Estimator(clf1)
d = Wrapped(clf1)
assert d._estimator_type == clf1._estimator_type


Expand All @@ -75,7 +75,7 @@ def test_get_params():
def test_set_params():
d = from_sklearn(clf1)
d2 = d.set_params(C=5)
assert isinstance(d2, Estimator)
assert isinstance(d2, Wrapped)
# Check no mutation
assert d2.get_params()['C'] == 5
assert d2.compute().C == 5
Expand All @@ -84,35 +84,35 @@ def test_set_params():


def test_setattr():
d = Estimator(clf1)
d = Wrapped(clf1)
with pytest.raises(AttributeError):
d.C = 10


def test_getattr():
d = Estimator(clf1)
d = Wrapped(clf1)
assert d.C == clf1.C
with pytest.raises(AttributeError):
d.not_a_real_parameter


def test_dir():
d = Estimator(clf1)
d = Wrapped(clf1)
attrs = dir(d)
assert 'C' in attrs


def test_repr():
d = Estimator(clf1)
d = Wrapped(clf1)
res = repr(d)
assert res.startswith('Dask')
assert res.startswith('Wrapped')


def test_fit():
d = from_sklearn(clf1)
fit = d.fit(X_iris, y_iris)
assert fit is not d
assert isinstance(fit, Estimator)
assert isinstance(fit, Wrapped)

res = fit.compute()
assert hasattr(res, 'coef_')
Expand Down
30 changes: 19 additions & 11 deletions dklearn/estimator.py → dklearn/wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask.delayed import Delayed
from sklearn.base import clone, BaseEstimator
from toolz import merge
from textwrap import wrap

from .core import DaskBaseEstimator, from_sklearn
from .utils import unpack_arguments
Expand Down Expand Up @@ -39,23 +40,20 @@ def _fit_transform(est, X, y, kwargs):


class ClassProxy(object):
def __init__(self, cls):
def __init__(self, wrap, cls):
self.wrap = wrap
self.cls = cls

@property
def __name__(self):
return 'Dask' + self.cls.__name__
return self.wrap.__name__

def __call__(self, *args, **kwargs):
return Estimator(self.cls(*args, **kwargs))
return self.wrap(self.cls(*args, **kwargs))


class WrapperMixin(DaskBaseEstimator):
"""Mixin class for dask estimators that wrap sklearn estimators"""
@property
def __class__(self):
return ClassProxy(type(self._est))

@classmethod
def _finalize(cls, res):
return res[0]
Expand Down Expand Up @@ -94,7 +92,7 @@ def __dir__(self):
return list(o)


class Estimator(WrapperMixin, BaseEstimator):
class Wrapped(WrapperMixin, BaseEstimator):
"""A class for wrapping a scikit-learn estimator.
All operations done on this estimator are pure (no mutation), and are done
Expand All @@ -114,6 +112,16 @@ def __init__(self, est, dask=None, name=None):
self._name = name
self._est = est

@property
def __class__(self):
return ClassProxy(type(self), type(self._est))

def __repr__(self):
class_name = type(self).__name__
est = ''.join(map(str.strip, repr(self._est).splitlines()))
return '\n'.join(wrap('{0}({1})'.format(class_name, est),
subsequent_indent=" "*10))

@classmethod
def from_sklearn(cls, est):
"""Wrap a scikit-learn estimator"""
Expand All @@ -126,7 +134,7 @@ def fit(self, X, y, **kwargs):
X, y, dsk = unpack_arguments(X, y)
dsk.update(self.dask)
dsk[name] = (_fit, self._name, X, y, kwargs)
return Estimator(self._est, dsk, name)
return Wrapped(self._est, dsk, name)

def predict(self, X):
name = 'predict-' + tokenize(self, X)
Expand Down Expand Up @@ -155,7 +163,7 @@ def _fit_transform(self, X, y, **kwargs):
dsk[fit_tr_name] = (_fit_transform, self._name, X, y, kwargs)
dsk1 = merge({fit_name: (getitem, fit_tr_name, 0)}, dsk, self.dask)
dsk2 = merge({tr_name: (getitem, fit_tr_name, 1)}, dsk, self.dask)
return Estimator(self._est, dsk1, fit_name), Delayed(tr_name, [dsk2])
return Wrapped(self._est, dsk1, fit_name), Delayed(tr_name, [dsk2])


from_sklearn.dispatch.register(BaseEstimator, Estimator.from_sklearn)
from_sklearn.dispatch.register(BaseEstimator, Wrapped.from_sklearn)

0 comments on commit 09b7878

Please sign in to comment.