
DOC, TST: Wrapping of PyTorch models #699

Merged: 43 commits merged into dask:master on Jul 29, 2020

Conversation

@stsievert (Member) commented Jul 15, 2020:

What does this PR implement?
It provides wrappers for Keras and PyTorch models, primarily aimed at model selection, by relying on SciKeras and Skorch.

Of course, this is a very thin wrapper. I think it's warranted for the following reasons:

  • PyTorch and Keras/TensorFlow are very popular. No other deep learning library comes close (e.g., Chainer or MXNet), and both PyTorch and Keras are about 15x more popular than Scikit-learn on Google Trends.
  • Out of the box, Skorch is not suited for model selection: it prints too much, performs a validation split by default, and runs too many epochs per partial_fit call.

References issues/PRs

This PR will be a WIP until adriangb/scikeras#19 is resolved.

Edit: Now the Dask-ML documentation shows the following in the sidebar:
[Screenshot: Dask-ML documentation sidebar]
A "Keras" bullet will be added when #713 is merged.

@stsievert changed the title from "ENH: Wrap PyTorch/Keras models" to "WIP: ENH: Wrap PyTorch/Keras models" on Jul 15, 2020
@TomAugspurger (Member) left a comment:

Can you expand a bit more on the value added by our wrappers? In particular, to what extent can the changes here be pushed upstream? For example

Out of the box, Skorch is not suited for model selection: it prints too much, performs a validation split by default, and runs too many epochs per partial_fit call.

Could the Skorch defaults be changed? Could we document appropriate defaults to use instead?
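
For reference, here is a minimal sketch of what model-selection-friendly, non-default Skorch settings might look like (TinyModule is a placeholder, not part of this PR):

import torch.nn as nn
from skorch import NeuralNetClassifier

class TinyModule(nn.Module):
    # Placeholder network; any torch.nn.Module would do here.
    def __init__(self, hidden=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, hidden), nn.ReLU(), nn.Linear(hidden, 2), nn.Softmax(dim=-1)
        )

    def forward(self, X):
        return self.net(X)

net = NeuralNetClassifier(
    TinyModule,
    verbose=0,         # silence Skorch's per-epoch printing
    train_split=None,  # disable the internal validation split
    max_epochs=1,      # one epoch per call to (partial_)fit
)

With settings like these, each partial_fit call makes roughly one pass over the data it receives, which suits adaptive searches that call partial_fit repeatedly.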

Regardless of how we handle the wrappers, I think there is value in ensuring that our model selection algorithms work with these estimators. So having the tests, and ideally examples at dask-examples, will be valuable.

@@ -30,3 +30,7 @@ dependencies:
   - pip
   - pip:
     - pytest-azurepipelines
+    - tensorflow
Member commented:

Can you keep as much as possible in the conda section? Especially for things like tensorflow & torch.

Also... I don't feel great about including these huge dependencies just for a small subset of the library. I like how dask-gateway does things https://github.com/dask/dask-gateway/blob/master/.travis.yml#L37. That's using Travis. Can you see if something similar is possible for azure-pipelines?

@stsievert (Member, Author) commented Jul 15, 2020:

Would it be okay to only run the PyTorch/Keras tests on the master branch? I'm not sure I can get a commit message trigger working, but a branch trigger looks straightforward.

Member commented:

Yeah I think that's fine.
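
For illustration, a branch-gated install step in Azure Pipelines YAML might look roughly like this (the step name and package list are assumptions, not the final ci/posix.yaml):

- script: pip install torch skorch
  displayName: Install deep learning test dependencies
  condition: eq(variables['Build.SourceBranchName'], 'master')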

Member commented:

Will probably need to remove these here, and add a secondary conda / pip install that only runs on certain jobs.

@stsievert (Member, Author) commented:

Thanks. That's in posix.yaml. This PR isn't ready for merge; SciKeras should have a new release soon that incorporates the recent changes.

@stsievert (Member, Author) commented:

Azure Pipelines is installing SciKeras from master right now.

(Resolved review threads on dask_ml/wrappers.py)
@mrocklin (Member) commented:

What would be necessary to support PyTorch/Keras natively? Could we provide a function that called fit/partial_fit on Scikit-Learn estimators and something else on torch/keras models?

@TomAugspurger (Member) commented Jul 15, 2020 via email

@mrocklin (Member) commented Jul 15, 2020 via email

@stsievert (Member, Author) commented:

Can you expand a bit more on the value added by our wrappers? In particular, to what extent can the changes here be pushed upstream? ... having the tests, and ideally examples at dask-examples will be valuable

That's part of why I opened this PR. I think having tests, an example at dask-examples and another Dask-ML documentation page would suffice to meet my goals. I aim for these implementations to be easy to discover and work well with model selection. I think those needs are met without adding an implementation in dask_ml.wrappers.

I'll rework this PR to be focused on documentation/testing, and open a PR in dask-examples.

@stsievert changed the title from "WIP: ENH: Wrap PyTorch/Keras models" to "WIP: DOC, TST: Wrapping of PyTorch/Keras models" on Jul 15, 2020
@stsievert changed the title from "WIP: DOC, TST: Wrapping of PyTorch/Keras models" to "DOC, TST: Wrapping of PyTorch/Keras models" on Jul 15, 2020
@stsievert (Member, Author) commented:

Could we provide a function that called fit/partial_fit on Scikit-Learn estimators and something else on torch/keras models?

For model selection, not really; @TomAugspurger is right. I suppose it would be possible to hack together a solution that runs arbitrary training functions using distributed Variables:

import time
from typing import Any, Dict

from distributed import Variable, get_client
from sklearn.base import BaseEstimator


def train_model(hidden=4, epochs=10, data: Variable = None, stop: Variable = None):
    model = create_model(hidden=hidden)  # create_model builds the Keras model (not shown)
    X_train, y_train, X_test, y_test = ...  # load and split the data (not shown)
    for epoch in range(epochs):
        model.fit(X_train, y_train, epochs=1)
        score = model.evaluate(X_test, y_test)  # Keras's evaluate returns the loss; it stands in for a score here
        datum = {"score": score, "pf_calls": epoch}
        d = data.get()
        data.set(d + [datum])
        if stop.get():
            break


class FunctionTrainer(BaseEstimator):
    def __init__(self, fn, **kwargs):
        vars(self).update(kwargs)
        self.fn = fn
        self._data = Variable("_data")
        self._stop = Variable("_stop")
        self._pf_calls = 0

    def _wait_for_training_to_complete(self) -> Dict[str, Any]:
        # Block until the training loop has published a datum for the epoch
        # this estimator is waiting on.
        while True:
            data = self._data.get()
            pf_calls = {d["pf_calls"] for d in data}
            if self._pf_calls in pf_calls:
                break
            time.sleep(0.1)
        return data[-1]

    def _initialize(self):
        self._data.set([])
        self._stop.set(False)
        # get_params can't see parameters hidden behind **kwargs, so collect
        # the public attributes that __init__ set instead.
        kwargs = {k: v for k, v in vars(self).items()
                  if not k.startswith("_") and k != "fn"}
        client = get_client()
        # Keep a reference to the future so Dask doesn't cancel the task.
        self._future = client.submit(self.fn, data=self._data, stop=self._stop, **kwargs)

    def _stop_training(self):
        self._stop.set(True)

    def partial_fit(self, X, y):
        if self._pf_calls == 0:
            self._initialize()
        self._wait_for_training_to_complete()
        self._pf_calls += 1
        return self

    def score(self, X, y):
        datum = self._wait_for_training_to_complete()
        return datum["score"]

We could make this work with Dask-ML's model selection; it'd have to call FunctionTrainer._stop_training when it kills a model.

Training a single model is easier. The PyTorch/Keras implementation would be similar to dask-glm: the model and the optimization would reside client-side, and the Dask workers would be tasked with computing gradients.
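
A rough sketch of that pattern, assuming a hypothetical make_model factory and a list of preloaded (X, y) tensor chunks called data_chunks; the optimizer stays on the client while workers compute per-chunk gradients:

import torch
from distributed import Client

def chunk_gradient(state_dict, X, y):
    # Rebuild the model on the worker and compute this chunk's gradients.
    model = make_model()  # hypothetical model factory
    model.load_state_dict(state_dict)
    loss = torch.nn.functional.mse_loss(model(X), y)
    loss.backward()
    return [p.grad for p in model.parameters()]

client = Client()
model = make_model()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(10):
    futures = [
        client.submit(chunk_gradient, model.state_dict(), X, y)
        for X, y in data_chunks  # hypothetical (X, y) chunks
    ]
    per_chunk = client.gather(futures)
    # Average the per-chunk gradients into the client-side model, then step.
    for param, grads in zip(model.parameters(), zip(*per_chunk)):
        param.grad = torch.stack(grads).mean(dim=0)
    opt.step()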

(Resolved review threads on docs/source/keras.rst)

@TomAugspurger (Member) left a comment:

Looking pretty good @stsievert. I think there's a -y missing in the conda install.

I'll also try this out locally a little later.

(Resolved review threads on ci/posix.yaml)
stsievert and others added 3 commits July 24, 2020 14:47
Co-authored-by: Tom Augspurger <TomAugspurger@users.noreply.github.com>
@stsievert changed the title from "DOC, TST: Wrapping of PyTorch/Keras models" to "DOC, TST: Wrapping of PyTorch models" on Jul 26, 2020
@stsievert (Member, Author) commented:

I've put Keras in a separate PR, #713. I think it needs some more work, and I don't think it should block this PR. For more detail on the issues, see #713 (comment).

Now this PR focuses on documenting a PyTorch wrapper and reorganizing the documentation.

@stsievert (Member, Author) commented:

Are we particularly tied to isort 4.3.21? I cannot get pytest.importorskip followed by the torch imports to work with it. It works under isort >= 5, which supports action comments like "isort: skip" or "isort: split".
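
For example, with isort >= 5 an action comment keeps the torch imports below the importorskip call (a sketch of the pattern, not the actual test file):

import pytest

pytest.importorskip("torch")

# isort sorts everything below this action comment as a separate section,
# so these imports stay after the importorskip call.
# isort: split
import torch
import skorch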

@TomAugspurger (Member) commented:

Are we particularly tied to isort 4.3.21?

No, but I don't think we'd want those anyway. 5b6e20c will hopefully work.

@TomAugspurger (Member) left a comment:

(Review threads on ci/posix.yaml)
@TomAugspurger (Member) commented:

Ah @stsievert, I think torch & skorch are being installed into the base conda environment, but the tests are run in the dask-ml-test env.

@stsievert (Member, Author) commented:

The Windows CI is failing on the model selection tests test_small and test_warns_scores_per_fit. The traceback reports TimeoutError or CancelledError; both tests pass on my non-Windows machine.

@TomAugspurger (Member) left a comment:

Great, thanks. I see that error on 25-50% of the CI runs.

I'll push a commit uncommenting the condition.

(Resolved review thread on ci/posix.yaml)
@TomAugspurger merged commit 5c3179e into dask:master on Jul 29, 2020
@TomAugspurger (Member) commented:

Thanks @stsievert!
