
Provide wrappers for popular ML libraries #696

Open · stsievert opened this issue Jul 12, 2020 · 14 comments

@stsievert (Member) commented Jul 12, 2020

It'd be convenient to support using Keras or PyTorch models in model selection. There are two issues:

  1. Keras/PyTorch models don't conform to the Scikit-learn API.
  2. Keras models are not picklable.

I'm imagining this interface:

from torchvision.models import resnet18
import torch.optim as optim
from dask_ml.wrappers import PyTorchClassifier

pytorch_model = resnet18()
sklearn_model = PyTorchClassifier(
    model=pytorch_model,
    model__alpha=1e-2,  # if resnet18 had a kwarg `alpha`
    optimizer=optim.SGD,
    optimizer__lr=0.1,
)
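
For context, the payoff of such a wrapper is that anything exposing the scikit-learn API (including partial_fit for incremental training) plugs straight into Dask-ML's model selection. A sketch, assuming the proposed wrapper above and hypothetical training data X, y:

from dask_ml.model_selection import HyperbandSearchCV

# Parameter names follow scikit-learn's double-underscore convention.
params = {"optimizer__lr": [0.01, 0.1, 1.0]}
search = HyperbandSearchCV(sklearn_model, params, max_iter=9)
# search.fit(X, y)  # requires `sklearn_model` to implement partial_fit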

Related issues/PRs
Same complaint in dask/distributed: dask/distributed#3873

@stsievert (Member Author) commented Jul 12, 2020

I think this is possible with these wrappers (edit: these libraries are discussed below):

  • adadamp, which provides distributed training and a Scikit-learn API for PyTorch models.
  • saturncloud/dask-pytorch-ddp, which allows usage of Dask clusters with native PyTorch distributed code.
  • SciKeras, which provides a Scikit-learn API to Keras.

SciKeras and Skorch are now mentioned in Dask-ML's documentation on wrappers (see 1 and 2).

@adriangb commented:

If you're looking for a way to make Keras models conform to the scikit-learn API, check out SciKeras (full disclosure: I'm the author).

@TomAugspurger (Member) commented:

Thanks for the link, Adrian. To minimize our maintenance burden, I'd aim for the goal that our model_selection estimators work with any model implementing the scikit-learn interface, and encourage the development and use of wrappers like Skorch and SciKeras.

On top of that, we have the additional burden of these models needing to work well with distributed's serialization. To the extent possible, that functionality should be in the projects themselves (making Keras models picklable) or in distributed.

@mrocklin (Member) commented:

How hard is it to support PyTorch/Keras fit/predict APIs? If it's as simple as writing a function like the following, then I would be in favor:

def fit(estimator, X, y=None):
    if hasattr(estimator, "fit"):
        return estimator.fit(X, y)
    elif hasattr(estimator, ...):  # pytorch-like
        return ...
    elif hasattr(estimator, ...):  # keras-like
        return ...
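
One wrinkle: Keras models also expose fit, so the framework-specific checks would need to come first. A runnable sketch with the elided checks filled in; the marker attributes (train_on_batch for Keras, parameters/forward for PyTorch) are assumptions, not settled API:

def fit(estimator, X, y=None):
    if hasattr(estimator, "train_on_batch"):  # Keras-like (assumed marker)
        # Keras `fit` accepts numpy arrays directly.
        return estimator.fit(X, y)
    elif hasattr(estimator, "parameters") and hasattr(estimator, "forward"):
        # PyTorch-like (assumed marker): torch.nn.Module has no `fit`,
        # so a training loop would need to be written here.
        raise NotImplementedError("PyTorch training loop goes here")
    elif hasattr(estimator, "fit"):  # scikit-learn protocol
        return estimator.fit(X, y)
    raise TypeError(f"Don't know how to fit {type(estimator)}")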

@mrocklin (Member) commented:

For serialization, I think we have a decent PyTorch serializer in Distributed (early work from @stsievert, if I recall correctly). I don't think we have anything for Keras today.

@mrocklin (Member) commented:

Serialization is also something we could ask the RAPIDS folks for help with, like @quasiben, @jakirkham, and @pentschev. It's not RAPIDS, obviously, but these issues are often GPU-related and that team is familiar with them.

@adriangb commented Jul 13, 2020

I'm not 100% sure what the goal is here (I just came from the discussion in tensorflow/tensorflow#39609), but SciKeras adds serialization support to the Keras models it wraps. For example:

from scikeras.wrappers import KerasRegressor

keras_model = ...  # some Keras model object; can be Sequential or Functional

wrapped_model = KerasRegressor(keras_model)  # a serializable, scikit-learn-API-compliant estimator

So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?
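
(A quick sanity check of that claimed serialization support, assuming the wrapped_model above; both round-trips should succeed:)

import copy
import pickle

clone_a = copy.deepcopy(wrapped_model)               # deepcopy support
clone_b = pickle.loads(pickle.dumps(wrapped_model))  # pickle round-trip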

@quasiben (Member) commented:

Serialization is definitely something RAPIDS cares about. SciKeras looks interesting -- @adriangb, do you know if it forces a host-to-device transfer? Does it support __cuda_array_interface__? If so, I believe things are a lot easier for us.

cc @JohnZed: maybe PyTorch serialization is something cuML would also care about.


@stsievert (Member Author) commented:

> So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?

That's the plan. To do that, models need to support serialization and implement partial_fit (see adriangb/scikeras#17).

> maybe PyTorch serialization is something cuML would also care about

The PyTorch serializer in Distributed may not be GPU-optimized.

PyTorch has serialization support, even though they recently tried to remove it (pytorch/pytorch#38597). Skorch wraps PyTorch, and it looks like it supports GPUs: skorch/net.py#L1608.
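
For reference, wrapping a PyTorch module with Skorch looks roughly like this; a sketch with placeholder hyperparameters:

from skorch import NeuralNetClassifier
from torchvision.models import resnet18
import torch.optim as optim

net = NeuralNetClassifier(
    resnet18,        # an uninstantiated torch.nn.Module class
    optimizer=optim.SGD,
    lr=0.1,
    max_epochs=10,
    device="cuda",   # skorch moves the module and batches to the GPU
)
# net.fit(X, y); skorch also implements partial_fit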

@adriangb commented Jul 13, 2020

> do you know if it forces a host-to-device transfer? Does it support __cuda_array_interface__? If so, I believe things are a lot easier for us.

To be honest, I'm not familiar with these terms. All SciKeras does is implement copy.deepcopy- and pickle-compatible serialization. It does not expose a __cuda_array_interface__ attribute, so I think the answer is no.
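
(For anyone curious, the general recipe, not necessarily SciKeras's exact code, is to round-trip the Keras model through its own save format during pickling. A minimal sketch; PickleableKerasModel is a hypothetical wrapper:)

import os
import tempfile
import tensorflow as tf

class PickleableKerasModel:
    """Sketch: make a wrapped Keras model pickle- and deepcopy-compatible."""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Serialize via Keras's HDF5 format instead of pickling raw tensors.
        with tempfile.TemporaryDirectory() as d:
            path = os.path.join(d, "model.h5")
            self.model.save(path)
            with open(path, "rb") as f:
                return {"model_bytes": f.read()}

    def __setstate__(self, state):
        with tempfile.TemporaryDirectory() as d:
            path = os.path.join(d, "model.h5")
            with open(path, "wb") as f:
                f.write(state["model_bytes"])
            self.model = tf.keras.models.load_model(path)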

> models need to support serialization and implement partial_fit (see adriangb/scikeras#17).

Will take a look tonight!

@jakirkham (Member) commented:

How difficult would it be to implement pickling (like Matt did for PyTorch in pytorch/pytorch#9184) for Keras as well? There's a lot of value gained by supporting standard Python protocols. That's not to say there may not be additional gains from Dask serialization, just that having the standard protocol working would make interop with various distributed computing libraries (including Dask) easier.

@stsievert (Member Author) commented Jul 13, 2020

> How difficult would it be to implement pickling ... for Keras as well?

SciKeras has an implementation at scikeras/wrappers.py#L87. There's currently an open PR to merge this into TensorFlow/Keras master: tensorflow/tensorflow#39609.

Does that answer your question?

@stsievert (Member Author) commented:

Here are two more PyTorch wrappers:

  1. adadamp, which enables use of Dask clusters with PyTorch models and presents a Scikit-learn interface.
  2. saturncloud/dask-pytorch-ddp, which allows use of a Dask cluster with native PyTorch distributed code.
