Add Chained meta-estimator
Chained takes a scikit-learn estimator that supports `partial_fit`, and
implements a `fit` method that chains calls to `partial_fit` along the
partitions of the `X` and `y` inputs.
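
A minimal usage sketch (hypothetical data and estimator, not part of this
commit; assumes `X` and `y` are chunked dask arrays and that the wrapped
estimator implements `partial_fit`):

    import dask.array as da
    from sklearn.linear_model import SGDClassifier
    from dklearn import Chained

    # One dask chunk per partition; each partition feeds one `partial_fit` call.
    X = da.random.random((10000, 20), chunks=(1000, 20))
    y = da.random.randint(0, 2, size=10000, chunks=1000)

    clf = Chained(SGDClassifier())
    fitted = clf.fit(X, y)       # lazy; returns a new Chained, `clf` is untouched
    model = fitted.to_sklearn()  # computes the graph, yields a fitted estimator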
jcrist committed Jun 17, 2016
1 parent 0c3d0f6 commit b59d53b
Showing 9 changed files with 479 additions and 55 deletions.
1 change: 1 addition & 0 deletions dklearn/__init__.py
@@ -3,5 +3,6 @@
from .core import from_sklearn
from .pipeline import Pipeline
from .estimator import Estimator
from .chained import Chained

__version__ = '0.1.0'
140 changes: 140 additions & 0 deletions dklearn/chained.py
@@ -0,0 +1,140 @@
from __future__ import absolute_import, print_function, division

import numpy as np
from dask.base import tokenize
from dask.delayed import delayed
from sklearn.base import clone, BaseEstimator, is_classifier
from scipy import sparse

from .core import DaskBaseEstimator
from .estimator import Estimator
from .utils import unpack_arguments, unpack_as_lists_of_keys, check_X_y


_unique_chunk = delayed(np.unique, pure=True)


@delayed(pure=True)
def _unique_merge(x):
return np.unique(np.concatenate(x))
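
# Together these lazily compute the sorted union of class labels, e.g.
#   classes = _unique_merge([_unique_chunk(part) for part in y_parts])
# deduplicating each partition in parallel before a single merge task.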


def _maybe_stack(x):
"""Given a list of arrays, maybe stack them along their first axis.k
Works with both sparse and dense arrays."""
if isinstance(x, (tuple, list)):
# optimization to avoid copies if unneeded
if len(x) == 1:
return x[0]
if isinstance(x[0], np.ndarray):
return np.concatenate(x)
elif sparse.issparse(x[0]):
return sparse.vstack(x)
return x


def _partial_fit(est, X, y, classes, kwargs):
# XXX: this mutates est!
X = _maybe_stack(X)
y = _maybe_stack(y)
if classes is None:
return est.partial_fit(X, y, **kwargs)
return est.partial_fit(X, y, classes=classes, **kwargs)


class Chained(DaskBaseEstimator, BaseEstimator):
_finalize = staticmethod(lambda res: Chained(res[0]))

def __init__(self, estimator):
if not isinstance(estimator, (BaseEstimator, Estimator)):
            raise TypeError("`estimator` must be a scikit-learn estimator "
"or a dklearn.Estimator")

if not hasattr(estimator, 'partial_fit'):
raise ValueError("estimator must support `partial_fit`")

est = Estimator.from_sklearn(estimator)
object.__setattr__(self, 'estimator', est)

@property
def dask(self):
return self.estimator.dask

@property
def _name(self):
return self.estimator._name

@property
def _estimator_type(self):
return self.estimator._estimator_type

def set_params(self, **params):
if not params:
return self
if 'estimator' in params:
if len(params) == 1:
return Chained(params['estimator'])
raise ValueError("Setting params with both `'estimator'` and "
"nested parameters is ambiguous due to order of "
"operations. To change both estimator and "
"sub-parameters create a new `Chained`.")
sub_params = {}
for key, value in params.items():
split = key.split('__', 1)
if len(split) > 1 and split[0] == 'estimator':
sub_params[split[1]] = value
else:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self.__class__.__name__))
return Chained(self.estimator.set_params(**sub_params))

def __setattr__(self, k, v):
raise AttributeError("Attribute setting not permitted. "
"Use `set_params` to change parameters")

@classmethod
def from_sklearn(cls, est):
if isinstance(est, cls):
return est
return cls(est)

def to_sklearn(self, compute=True):
return self.estimator.to_sklearn(compute=compute)

def fit(self, X, y, **kwargs):
X, y = check_X_y(X, y)
x_parts, y_parts, dsk = unpack_as_lists_of_keys(X, y)
name = 'partial_fit-' + tokenize(self, X, y, **kwargs)

# Extract classes if applicable
if is_classifier(self):
classes = kwargs.pop('classes', None)
if classes is None:
classes = _unique_merge([_unique_chunk(i) for i in y_parts])
classes, dsk2 = unpack_arguments(classes)
dsk.update(dsk2)
else:
classes = None

# Clone so that this estimator isn't mutated
sk_est = clone(self.estimator._est)

dsk[(name, 0)] = (_partial_fit, sk_est, x_parts[0], y_parts[0],
classes, kwargs)

for i, (x, y) in enumerate(zip(x_parts[1:], y_parts[1:]), 1):
dsk[(name, i)] = (_partial_fit, (name, i - 1), x, y, None, kwargs)
out = Estimator(clone(sk_est), dsk, (name, len(x_parts) - 1))
return Chained(out)

def predict(self, X):
return self.estimator.predict(X)

def score(self, X, y, **kwargs):
return self.estimator.score(X, y, **kwargs)

def transform(self, X):
return self.estimator.transform(X)
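
For illustration, the graph built by `fit` for three partitions is a linear
chain (key names abbreviated; a sketch of the shape, not exact dask output):

    dsk[('partial_fit-TOKEN', 0)] = (_partial_fit, sk_est, x0, y0, classes, kwargs)
    dsk[('partial_fit-TOKEN', 1)] = (_partial_fit, ('partial_fit-TOKEN', 0), x1, y1, None, kwargs)
    dsk[('partial_fit-TOKEN', 2)] = (_partial_fit, ('partial_fit-TOKEN', 1), x2, y2, None, kwargs)

Each task consumes the estimator produced by the one before it, so the fit
itself runs sequentially; parallelism comes only from producing the partitions
and from the `classes` reduction.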
49 changes: 6 additions & 43 deletions dklearn/core.py
@@ -2,13 +2,12 @@

from functools import partial

from dask.base import Base, normalize_token, tokenize
from dask.delayed import Delayed
from dask.base import Base, normalize_token
from dask.optimize import fuse
from dask.threaded import get as threaded_get
from dask.utils import concrete, Dispatch
from dask.utils import Dispatch
from sklearn.base import BaseEstimator
from toolz import merge, identity
from toolz import identity


class DaskBaseEstimator(Base):
@@ -20,6 +19,9 @@ def _optimize(dsk, keys, **kwargs):
dsk2, deps = fuse(dsk, keys)
return dsk2

def _keys(self):
return [self._name]


@partial(normalize_token.register, BaseEstimator)
def normalize_BaseEstimator(est):
@@ -31,45 +33,6 @@ def normalize_dask_estimators(est):
return type(est).__name__, est._name


def unpack_arguments(*args):
"""Extracts dask values from args"""
out_args = []
dsks = []
for x in args:
t, d = unpack(x)
out_args.append(t)
dsks.extend(d)
dsk = merge(dsks)
return tuple(out_args) + (dsk,)


def unpack(expr):
"""Normalize a python object and extract all sub-dasks.
Parameters
----------
expr : object
The object to be normalized.
Returns
-------
task : normalized task to be run
dasks : list of dasks that form the dag for this task
"""
if isinstance(expr, Delayed):
return expr.key, expr._dasks
if isinstance(expr, Base):
name = tokenize(expr, pure=True)
keys = expr._keys()
if isinstance(expr, DaskBaseEstimator):
dsk = expr.dask
else:
dsk = expr._optimize(expr.dask, keys)
dsk[name] = (expr._finalize, (concrete, keys))
return name, [dsk]
return expr, []


def from_sklearn(est):
"""Wrap a scikit-learn estimator in a dask object."""
return from_sklearn.dispatch(est)
12 changes: 6 additions & 6 deletions dklearn/estimator.py
@@ -2,12 +2,13 @@

from operator import getitem

from sklearn.base import clone, BaseEstimator
from dask.base import tokenize
from dask.delayed import Delayed
from sklearn.base import clone, BaseEstimator
from toolz import merge

from .core import DaskBaseEstimator, unpack_arguments, from_sklearn
from .core import DaskBaseEstimator, from_sklearn
from .utils import unpack_arguments


def _fit(est, X, y, kwargs):
@@ -62,20 +63,19 @@ def __init__(self, est, dask=None, name=None):
raise TypeError("Expected instance of BaseEstimator, "
"got {0}".format(type(est).__name__))
if dask is None and name is None:
name = 'from_sklearn-' + tokenize(est)
name = 'estimator-' + tokenize(est)
dask = {name: est}
elif dask is None or name is None:
raise ValueError("Must provide both dask and name")
self.dask = dask
self._name = name
self._est = est

def _keys(self):
return [self._name]

@classmethod
def from_sklearn(cls, est):
"""Wrap a scikit-learn estimator"""
if isinstance(est, cls):
return est
return cls(est)

def to_sklearn(self, compute=True):
6 changes: 1 addition & 5 deletions dklearn/pipeline.py
@@ -32,9 +32,6 @@ def dask(self):
self._dask = dsk
return dsk

def _keys(self):
return [self._name]

@classmethod
def from_sklearn(cls, est):
if not isinstance(est, pipeline.Pipeline):
@@ -51,7 +48,6 @@ def to_sklearn(self, compute=True):

def set_params(self, **params):
if not params:
# Simple optimisation to gain speed (inspect is slow)
return self
if 'steps' in params:
if len(params) == 1:
@@ -64,7 +60,7 @@ def set_params(self, **params):
sub_params = dict((n, {}) for n in self.named_steps)
for key, value in params.items():
split = key.split('__', 1)
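            # split('__', 1) returns at most two parts, so len(split) > 1
            # below is enough to recognize a nested parameter key.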
if len(split) > 1 and split[0] in sub_params and len(split) == 2:
if len(split) > 1 and split[0] in sub_params:
# nested objects case
sub_params[split[0]][split[1]] = value
else: