Blockwise Metaestimator #190

Merged
merged 8 commits on Jun 4, 2018
21 changes: 21 additions & 0 deletions dask_ml/_partial.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import os
import warnings
from abc import ABCMeta

import numpy as np
@@ -22,6 +23,13 @@ class _WritableDoc(ABCMeta):
# TODO: Py2: remove all this


_partial_deprecation = (
"'{cls.__name__}' is deprecated. Use "
"'dask_ml.wrappers.Incremental({base.__name__}(), **kwargs)' "
Member

Should Incremental be top-level? This might be a broader discussion though about namespaces.

Member Author

I dislike the .wrappers namespace, but I haven't put much thought into a replacement.

"instead."
)


@six.add_metaclass(_WritableDoc)
class _BigPartialFitMixin(object):
""" Wraps a partial_fit enabled estimator for use with Dask arrays """
@@ -30,6 +38,7 @@ class _BigPartialFitMixin(object):
_fit_kwargs = []

def __init__(self, **kwargs):
self._deprecated()
missing = set(self._init_kwargs) - set(kwargs)

if missing:
@@ -40,6 +49,15 @@ def __init__(self, **kwargs):
setattr(self, kwarg, kwargs.pop(kwarg))
super(_BigPartialFitMixin, self).__init__(**kwargs)

@classmethod
def _deprecated(cls):
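# Walk the MRO to find the first scikit-learn base class; its name is
# interpolated into the deprecation message below.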
for base in cls.mro():
if base.__module__.startswith('sklearn'):
break

warnings.warn(_partial_deprecation.format(cls=cls, base=base),
FutureWarning)

@classmethod
def _get_param_names(cls):
# Evil hack to make sure repr, get_params work
@@ -194,6 +212,9 @@ def _copy_partial_doc(cls):

insert = """

.. deprecated:: 0.6.0
Use the :class:`dask_ml.wrappers.Incremental` meta-estimator instead.

This class wraps scikit-learn's {classname}. When a dask-array is passed
to our ``fit`` method, the array is passed block-wise to the scikit-learn
class' ``partial_fit`` method. This will allow you to fit the estimator
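The deprecation message above points users at the new wrapper. For context, a minimal sketch of the recommended migration (the required keyword arguments depend on the underlying estimator; ``classes`` is what ``SGDClassifier.partial_fit`` expects):

.. code-block:: python

   # Before (deprecated): a Partial* wrapper, with partial_fit kwargs
   # passed to the constructor.
   from dask_ml.linear_model import PartialSGDClassifier
   est = PartialSGDClassifier(classes=[0, 1])

   # After: wrap the plain scikit-learn estimator in Incremental.
   from sklearn.linear_model import SGDClassifier
   from dask_ml.wrappers import Incremental
   est = Incremental(SGDClassifier(), classes=[0, 1])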
94 changes: 93 additions & 1 deletion dask_ml/wrappers.py
@@ -125,7 +125,27 @@ def transform(self, X):
return transform(X)

def score(self, X, y):
# TODO: re-implement some scoring functions.
"""Returns the score on the given data.

This uses the scoring defined by ``estimator.score``. This is
currently immediate and sequential. In the future, this will be
delayed and parallel.

Parameters
----------
X : array-like, shape = [n_samples, n_features]
Input data, where n_samples is the number of samples and
n_features is the number of features.

y : array-like, shape = [n_samples] or [n_samples, n_output], optional
Target relative to X for classification or regression;
None for unsupervised learning.

Returns
-------
score : float
"""
return self.estimator.score(X, y)

def predict(self, X):
@@ -199,6 +219,78 @@ def _check_method(self, method):
return getattr(self.estimator, method)


class Incremental(ParallelPostFit):
"""Metaestimator for feeding Dask Arrays to an estimator blockwise.

This wrapper provides a bridge between Dask objects and estimators
implementing the ``partial_fit`` API. These *incremental learners* can
train on batches of data. This fits well with Dask's blocked data
structures.

See the `list of incremental learners`_ in the scikit-learn documentation
for estimators that implement the ``partial_fit`` API. Note that
`Incremental` is not limited to these classes; it will work with any
estimator implementing ``partial_fit``, including those defined outside of
scikit-learn itself.

Calling :meth:`Incremental.fit` with a Dask Array will pass each block of
the Dask array or arrays to ``estimator.partial_fit`` *sequentially*.

Like :class:`ParallelPostFit`, the methods available after fitting (e.g.
:meth:`Incremental.predict`) are all parallel and delayed.

.. _list of incremental learners: http://scikit-learn.org/stable/modules/scaling_strategies.html#incremental-learning # noqa

Parameters
----------
estimator : Estimator
Any object supporting the scikit-learn ``partial_fit`` API.
**kwargs
Additional keyword arguments passed through to the underlying
estimator's `partial_fit` method.

Examples
--------
>>> from dask_ml.wrappers import Incremental
>>> from dask_ml.datasets import make_classification
>>> import sklearn.linear_model
>>> X, y = make_classification(chunks=25)
>>> est = sklearn.linear_model.SGDClassifier()
>>> clf = Incremental(est, classes=[0, 1])
>>> clf.fit(X, y)
"""
def __init__(self, estimator, **kwargs):
self.estimator = estimator
self.fit_kwargs = kwargs

def fit(self, X, y=None):
from ._partial import fit

fit_kwargs = self.fit_kwargs or {}
result = fit(self.estimator, X, y, **fit_kwargs)

# Copy the learned attributes over to self
attrs = {k: v for k, v in vars(result).items() if k.endswith('_')}
for k, v in attrs.items():
setattr(self, k, v)
return self

def partial_fit(self, X, y=None):
"""Fit the underlying estimator.

This is identical to ``fit``.

Parameters
----------
X, y : array-like

Returns
-------
self : object
"""
return self.fit(X, y)


def _first_block(dask_object):
"""Extract the first block / partition from a dask object
"""
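Since ``Incremental`` subclasses ``ParallelPostFit``, the post-fit methods remain parallel and lazy on Dask Arrays. A minimal usage sketch based on the docstring example above (the ``random_state`` value is arbitrary):

.. code-block:: python

   from dask_ml.wrappers import Incremental
   from dask_ml.datasets import make_classification
   from sklearn.linear_model import SGDClassifier

   X, y = make_classification(chunks=25)
   clf = Incremental(SGDClassifier(random_state=0), classes=[0, 1])
   clf.fit(X, y)           # each block is passed to partial_fit sequentially
   preds = clf.predict(X)  # a lazy, block-wise dask array
   preds.compute()         # materialize the predictions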
8 changes: 7 additions & 1 deletion docs/source/changelog.rst
@@ -7,7 +7,13 @@ Version 0.6.0
API Breaking Changes
--------------------

- Removed the `get` keyword from the incremental learner ``fit`` methods (:pr:`187`)
- Removed the `get` keyword from the incremental learner ``fit`` methods. (:pr:`187`)
- Deprecated the various ``Partial*`` estimators in favor of the :class:`dask_ml.wrappers.Incremental` meta-estimator (:pr:`190`)

Enhancements
------------

- Added a new meta-estimator :class:`dask_ml.wrappers.Incremental` for wrapping any estimator with a `partial_fit` method. See :ref:`incremental.blockwise-metaestimator` for more. (:pr:`190`)

Version 0.5.0
~~~~~~~~~~~~~
90 changes: 52 additions & 38 deletions docs/source/incremental.rst
@@ -3,55 +3,69 @@
Incremental Learning
====================

.. currentmodule:: dask_ml
Some estimators can be trained incrementally -- without seeing the entire
dataset at once. Scikit-Learn provides the ``partial_fit`` API to stream batches
of data to an estimator that can be fit in batches.

.. autosummary::
naive_bayes.PartialBernoulliNB
naive_bayes.PartialMultinomialNB
linear_model.PartialSGDRegressor
linear_model.PartialSGDClassifier
linear_model.PartialPerceptron
linear_model.PartialPassiveAggressiveClassifier
linear_model.PartialPassiveAggressiveRegressor
cluster.PartialMiniBatchKMeans
base._BigPartialFitMixin

Scikit-Learn's Partial Fit
--------------------------
Normally, if you pass a Dask Array to an estimator expecting a NumPy array,
the Dask Array will be converted to a single, large NumPy array. On a single
machine, you'll likely run out of RAM and crash the program. On a distributed
cluster, all the workers will send their data to a single machine and crash it.

Some scikit-learn models support `incremental learning`_ with the
``.partial_fit`` API. These models can see small batches of the dataset and update
their parameters as new data arrives.
:class:`dask_ml.wrappers.Incremental` provides a bridge between Dask and
Scikit-Learn estimators supporting the ``partial_fit`` API. You wrap the
underlying estimator in ``Incremental``. Dask-ML will sequentially pass each
block of a Dask Array to the underlying estimator's ``partial_fit`` method.

.. code-block:: python
.. _incremental.blockwise-metaestimator:

for X_block, y_block in iterator_of_numpy_arrays:
est.partial_fit(X_block, y_block)
Incremental Meta-estimator
--------------------------

This block-wise learning fits nicely with Dask's block-wise nature: Dask
arrays are composed of many smaller NumPy arrays. Dask dataframes and arrays
provide an intuitive way to preprocess your data and then send
that data to an incremental model piece by piece. Dask-ML will hide the
``.partial_fit`` mechanics from you, so that the usual ``.fit`` API will work
on larger-than-memory datasets. These wrappers can be dropped into a
:class:`sklearn.pipeline.Pipeline` just like normal. In Dask-ml, all of these
estimators are prefixed with ``Partial``, e.g. :class:`PartialSGDClassifier`.
.. currentmodule:: dask_ml

.. autosummary::
wrappers.Incremental

.. note::
:class:`dask_ml.wrappers.Incremental` is a meta-estimator (an estimator that
takes another estimator) that bridges scikit-learn estimators expecting
NumPy arrays, and users with large Dask Arrays.

While these wrappers are useful for fitting on larger than memory datasets
they do not offer any kind of parallelism while training. Calls to
``.fit()`` will be entirely sequential.
Each *block* of a Dask Array is fed to the underlying estimator's
``partial_fit`` method. The training is entirely sequential, so you won't
see large training-time speedups from parallelism. In a distributed
environment, you should see some speedup from avoiding extra IO, and from
the fact that models are typically much smaller than the data, so they are
faster to move between machines.

Example
-------

.. ipython:: python

from dask_ml.linear_model import PartialSGDRegressor
from dask_ml.datasets import make_classification
X, y = make_classification(n_samples=1000, chunks=500)
est = PartialSGDRegressor()
est.fit(X, y)
from dask_ml.wrappers import Incremental
from sklearn.linear_model import SGDClassifier

X, y = make_classification(chunks=25)
X

estimator = SGDClassifier(random_state=10)
clf = Incremental(estimator, classes=[0, 1])
clf.fit(X, y)

In this example, we make a (small) random Dask Array. It has 100 samples,
broken into 4 blocks of 25 samples each. The chunking is only along the
first axis (the samples). There is no chunking along the features.

You instantiate the underlying estimator as usual. It really is just a
scikit-learn compatible estimator, and it will be trained normally via its
``partial_fit`` method.

When wrapping the estimator in :class:`Incremental`, you need to pass any
keyword arguments that are expected by the underlying ``partial_fit`` method.
With :class:`sklearn.linear_model.SGDClassifier`, we're required to provide
the list of unique ``classes`` in ``y``.

Notice that we call the regular ``.fit`` method for training. Dask-ML takes
care of passing each block to the underlying estimator for you.

.. _incremental learning: http://scikit-learn.org/stable/modules/scaling_strategies.html#incremental-learning
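
Conceptually, the block-wise training performed by ``Incremental.fit`` looks roughly like the sketch below; this is illustrative only, as the real implementation chains the ``partial_fit`` calls through a dask graph rather than a plain Python loop.

.. code-block:: python

   # Illustrative only: roughly what Incremental.fit does with the blocks of
   # X and y from the example above.
   for X_block, y_block in zip(X.to_delayed().ravel(), y.to_delayed().ravel()):
       estimator.partial_fit(X_block.compute(), y_block.compute(),
                             classes=[0, 1])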
30 changes: 3 additions & 27 deletions docs/source/modules/api.rst
@@ -63,7 +63,8 @@ Dask-ML provides drop-in replacements for grid and randomized search.
Meta-estimators for scikit-learn
================================

dask-ml provides some meta-estimators parallelize certain components.
dask-ml provides some meta-estimators that help use regular scikit-learn
compatible estimators with Dask arrays.

.. currentmodule:: dask_ml

@@ -72,32 +73,7 @@ dask-ml provides some meta-estimators parallelize certain components.
:template: class.rst

wrappers.ParallelPostFit


Incremental Learning
====================

.. currentmodule:: dask_ml

Some scikit-learn estimators support out-of-core training through the
``partial_fit`` method. The following estimators wrap those scikit-learn
estimators, allowing them to be used in Pipelines and on Dask arrays and
dataframes. Training will still be serial, so these will not benefit from
a parallel or distributed training any more than the underlying estimator.

.. autosummary::
:toctree: generated/
:template: class.rst

cluster.PartialMiniBatchKMeans
linear_model.PartialPassiveAggressiveClassifier
linear_model.PartialPassiveAggressiveRegressor
linear_model.PartialPerceptron
linear_model.PartialSGDClassifier
linear_model.PartialSGDRegressor
naive_bayes.PartialBernoulliNB
naive_bayes.PartialMultinomialNB

wrappers.Incremental

:mod:`dask_ml.cluster`: Clustering
==================================
4 changes: 4 additions & 0 deletions tests/linear_model/test_neural_network.py
@@ -1,9 +1,12 @@
import pytest

from sklearn import neural_network as nn_
from dask_ml import neural_network as nn

from dask_ml.utils import assert_estimator_equal


@pytest.mark.filterwarnings("ignore::FutureWarning")
class TestMLPClassifier(object):

def test_basic(self, single_chunk_classification):
@@ -15,6 +18,7 @@ def test_basic(self, single_chunk_classification):
assert_estimator_equal(a, b)


@pytest.mark.filterwarnings("ignore::FutureWarning")
class TestMLPRegressor(object):

def test_basic(self, single_chunk_classification):
4 changes: 4 additions & 0 deletions tests/linear_model/test_passive_aggressive.py
@@ -1,9 +1,12 @@
import pytest

from sklearn import linear_model as lm_
from dask_ml import linear_model as lm

from dask_ml.utils import assert_estimator_equal


@pytest.mark.filterwarnings("ignore:'Partial:FutureWarning")
class TestPassiveAggressiveClassifier(object):

def test_basic(self, single_chunk_classification):
@@ -18,6 +21,7 @@ def test_basic(self, single_chunk_classification):
assert_estimator_equal(a, b, exclude=['loss_function_'])


@pytest.mark.filterwarnings("ignore:'Partial:FutureWarning")
class TestPassiveAggressiveRegressor(object):

def test_basic(self, single_chunk_regression):
3 changes: 3 additions & 0 deletions tests/linear_model/test_perceptron.py
@@ -1,9 +1,12 @@
import pytest

from sklearn.linear_model import Perceptron
from dask_ml.linear_model import PartialPerceptron

from dask_ml.utils import assert_estimator_equal


@pytest.mark.filterwarnings("ignore:'Partial:FutureWarning")
class TestPerceptron(object):

def test_basic(self, single_chunk_classification):
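
The test changes above only silence the new ``FutureWarning``. A hypothetical sketch, not part of this PR, of a test asserting that the warning is actually raised (assuming ``classes`` is the required init kwarg for ``PartialSGDClassifier``):

.. code-block:: python

   import pytest
   from dask_ml.linear_model import PartialSGDClassifier

   def test_partial_constructor_warns():
       # The deprecated Partial* constructors call _deprecated(), which
       # issues a FutureWarning pointing users at Incremental.
       with pytest.warns(FutureWarning):
           PartialSGDClassifier(classes=[0, 1])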