Provide wrappers for popular ML libraries #696
Comments
I think this is possible with these wrappers:
Edit: these libraries are discussed below:
SciKeras and Skorch are now mentioned in Dask-ML's documentation on wrappers (see 1 and 2). |
If you're looking for a way to make Keras models conform to the scikit-learn API, check out SciKeras (full disclosure: I'm the author) |
Thanks for the link Adrian. To minimize our maintenance burden, I'd aim for the goal that our On top of that, we have the additional burden of these models needing to work well with distributed's serialization. To the extent possible, that functionality should be in the projects themselves (making Keras models picklable) or in distributed. |
How hard is it to support PyTorch/Keras fit/predict APIs? If this is as simple as making a function like the following, then I would be in favor:
def fit(estimator, X, y=None):
if hasattr(estimator, "fit"):
return estimator.fit(X, y)
elif hasattr(estimator, ...): # pytorch-like
return ...
elif hasattr(estimator, ...): # keras-like
return ... |
For serialization I think that we have a decent Pytorch serializer in distributed (early work from @stsievert if I recall correctly). I don't think that we have anything for Keras today. |
Serialization is also maybe something that we could ask for help from the RAPIDS folks like @quasiben @jakirkham @pentschev . It's not RAPIDS obviously, but these are often GPU related and that team is familiar with these sorts of issues. |
I'm not 100% sure what the goal is here (I just came from the discussion in tensorflow/tensorflow#39609), but SciKeras adds serialization support to the Keras models it wraps. Example:
from scikeras.wrappers import KerasRegressor
keras_model = ...  # some Keras model object, can be Sequential or Functional
wrapped_model = KerasRegressor(keras_model)  # a serializable, scikit-learn API compliant estimator
So I guess you could just tell your users to wrap their Keras models before using them with dask-ml? |
Serialization is definitely something RAPIDS cares about. cc @JohnZed maybe pytorch serialization is something cuML would also care about |
Current pytorch serialization is here:
https://github.com/dask/distributed/blob/master/distributed/protocol/torch.py
It looks like it forces things to numpy though, and so may not be
GPU-optimized.
Rather than scikeras I'm still curious if we can make things more
torch/tf/keras-native cheaply
On Mon, Jul 13, 2020 at 8:35 AM Benjamin Zaitlen wrote:
Serialization is definitely something RAPIDS cares about. scikeras looks
interesting. @adriangb, do you know if it forces a host to device transfer?
Does it support the __cuda_array_interface__? If so, I believe things are a
lot easier for us.
cc @JohnZed maybe pytorch serialization is something cuML would also care about
|
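For readers unfamiliar with the term used above: `__cuda_array_interface__` is an attribute that GPU array libraries such as CuPy and Numba expose so other libraries can use their device memory directly. A minimal sketch of how a serializer might branch on it is below. `FakeDeviceArray`, `FakeHostArray`, and `describe_buffer` are hypothetical names made up for illustration; none of them come from distributed, SciKeras, or RAPIDS.

```python
class FakeDeviceArray:
    """Hypothetical stand-in for a GPU array (for example a CuPy array)."""

    def __init__(self, shape):
        # Real libraries store a device pointer here; 0 is a placeholder.
        self.__cuda_array_interface__ = {
            "shape": shape,
            "typestr": "<f4",
            "data": (0, False),
            "version": 2,
        }


class FakeHostArray:
    """Hypothetical stand-in for a host array (for example a NumPy array)."""

    def __init__(self, shape):
        # NumPy exposes host memory through __array_interface__.
        self.__array_interface__ = {
            "shape": shape,
            "typestr": "<f4",
            "data": (0, False),
            "version": 3,
        }


def describe_buffer(obj):
    """Report where an array-like object's memory appears to live."""
    if hasattr(obj, "__cuda_array_interface__"):
        return "device"  # candidate for serialization without a host copy
    if hasattr(obj, "__array_interface__"):
        return "host"
    return "unknown"


print(describe_buffer(FakeDeviceArray((2, 3))))
print(describe_buffer(FakeHostArray((2, 3))))
print(describe_buffer([1, 2, 3]))
```

A serializer that only understands host memory has to copy device arrays to the host first, which is the transfer being asked about.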
That's the plan. To do that, models need to support serialization and implement
PyTorch has serialization support, even though they recently tried to remove it! pytorch/pytorch#38597 Skorch wraps PyTorch, and it looks like it supports GPUs: skorch/net.py#L1608. |
To be honest, I am not familiar with these terms. All SciKeras does is implement
Will take a look tonight! |
How difficult would it be to implement pickling for Keras as well, like Matt did for PyTorch (pytorch/pytorch#9184)? There's a lot of value gained by supporting standard Python protocols. That's not to say there may not be additional gains with Dask serialization, just that having this standard protocol working would make interop with various distributed computing libraries (including Dask) easier. |
SciKeras has an implementation at scikeras/wrappers.py#L87. There's currently an open PR to merge this into TensorFlow/Keras master: tensorflow/tensorflow#39609 Does that answer your question?
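To illustrate what "supporting the standard pickle protocol" can look like in general, here is a minimal sketch using only the standard library. It is not the SciKeras or Keras implementation. `FakeModel`, `to_bytes`, and `from_bytes` are hypothetical stand-ins for a model object and its native save/load routines.

```python
import json
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a model that cannot be pickled as-is."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._lock = threading.Lock()  # locks cannot be pickled by default

    def to_bytes(self):
        # Stand-in for a library's native save-to-bytes routine.
        return json.dumps(self.weights).encode()

    @classmethod
    def from_bytes(cls, payload):
        # Stand-in for the matching native load routine.
        return cls(json.loads(payload.decode()))

    def __reduce__(self):
        # Tell pickle to rebuild the object on the receiving side by
        # calling from_bytes(<saved bytes>) instead of copying __dict__.
        return (type(self).from_bytes, (self.to_bytes(),))


model = FakeModel([0.1, 0.2, 0.3])
restored = pickle.loads(pickle.dumps(model))
print(restored.weights)
```

Because `__reduce__` routes pickling through the object's own save and load routines, anything that relies on `pickle` can move the object between processes without special-casing it.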
Here are two more PyTorch wrappers:
|
It'd be convenient to provide support for use of Keras or PyTorch models in model selection. There are two issues:
I'm imagining this interface:
Related issues/PRs
Same complaint in dask/distributed: dask/distributed#3873