[WIP] Enable Skorch+Dask-ML #5748
base: main
Conversation
Can one of the admins verify this patch?
Minimal example of the error that this PR fixes:

```python
from distributed import Client
from distributed.protocol import deserialize, serialize
from torch import nn
from skorch import NeuralNetClassifier

client = Client(processes=True, n_workers=2)

class MyModule(nn.Module):
    def __init__(self, num_units=10):
        super(MyModule, self).__init__()
        self.dense0 = nn.Linear(20, num_units)

    def forward(self, X, **kwargs):
        return self.dense0(X)

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    iterator_train__shuffle=True,
)

def test_serialize_skorch(net):
    net = net.initialize()
    return deserialize(*serialize(net))

output = client.run(test_serialize_skorch, net)
```

Trace on `main`:

```
distributed.worker - WARNING - Run Failed
Function: test_serialize_skorch
args: (<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=MyModule(
(dense0): Linear(in_features=20, out_features=10, bias=True)
),
))
kwargs: {}
Traceback (most recent call last):
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/worker.py", line 4648, in run
result = function(*args, **kwargs)
File "/tmp/ipykernel_9764/775506166.py", line 24, in test_serialize_skorch
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 417, in deserialize
return loads(header, frames)
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 180, in serialization_error_loads
raise TypeError(msg)
TypeError: Could not serialize object of type NeuralNetClassifier.
Traceback (most recent call last):
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 49, in dumps
result = pickle.dumps(x, **dump_kwargs)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
torch.save(cuda_attrs, f)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 340, in serialize
header, frames = dumps(x, context=context) if wants_context else dumps(x)
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 63, in pickle_dumps
frames[0] = pickle.dumps(
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 60, in dumps
result = cloudpickle.dumps(x, **dump_kwargs)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
cp.dump(obj)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
return Pickler.dump(self, obj)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
torch.save(cuda_attrs, f)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [1], in <module>
23 net = net.initialize()
24 return deserialize(*serialize(net))
---> 26 output = client.run(test_serialize_skorch, net)
File ~/skorch_dask_cuda/distributed/distributed/client.py:2750, in Client.run(self, function, workers, wait, nanny, on_error, *args, **kwargs)
2667 def run(
2668 self,
2669 function,
(...)
2675 **kwargs,
2676 ):
2677 """
2678 Run a function on all workers outside of task scheduling system
2679
(...)
2748 >>> c.run(print_state, wait=False) # doctest: +SKIP
2749 """
-> 2750 return self.sync(
2751 self._run,
2752 function,
2753 *args,
2754 workers=workers,
2755 wait=wait,
2756 nanny=nanny,
2757 on_error=on_error,
2758 **kwargs,
2759 )
File ~/skorch_dask_cuda/distributed/distributed/utils.py:309, in SyncMethodMixin.sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
307 return future
308 else:
--> 309 return sync(
310 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
311 )
File ~/skorch_dask_cuda/distributed/distributed/utils.py:363, in sync(loop, func, callback_timeout, *args, **kwargs)
361 if error[0]:
362 typ, exc, tb = error[0]
--> 363 raise exc.with_traceback(tb)
364 else:
365 return result[0]
File ~/skorch_dask_cuda/distributed/distributed/utils.py:348, in sync.<locals>.f()
346 if callback_timeout is not None:
347 future = asyncio.wait_for(future, callback_timeout)
--> 348 result[0] = yield future
349 except Exception:
350 error[0] = sys.exc_info()
File /datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/tornado/gen.py:762, in Runner.run(self)
759 exc_info = None
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
File ~/skorch_dask_cuda/distributed/distributed/client.py:2655, in Client._run(self, function, nanny, workers, wait, on_error, *args, **kwargs)
2652 continue
2654 if on_error == "raise":
-> 2655 raise exc
2656 elif on_error == "return":
2657 results[key] = exc
Input In [1], in test_serialize_skorch()
22 def test_serialize_skorch(net):
23 net = net.initialize()
---> 24 return deserialize(*serialize(net))
File ~/skorch_dask_cuda/distributed/distributed/protocol/serialize.py:417, in deserialize()
412 raise TypeError(
413 "Data serialized with %s but only able to deserialize "
414 "data with %s" % (name, str(list(deserializers)))
415 )
416 dumps, loads, wants_context = families[name]
--> 417 return loads(header, frames)
File ~/skorch_dask_cuda/distributed/distributed/protocol/serialize.py:180, in serialization_error_loads()
178 def serialization_error_loads(header, frames):
179 msg = "\n".join([ensure_bytes(frame).decode("utf8") for frame in frames])
--> 180 raise TypeError(msg)
TypeError: Could not serialize object of type NeuralNetClassifier.
Traceback (most recent call last):
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 49, in dumps
result = pickle.dumps(x, **dump_kwargs)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
torch.save(cuda_attrs, f)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 340, in serialize
header, frames = dumps(x, context=context) if wants_context else dumps(x)
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 63, in pickle_dumps
frames[0] = pickle.dumps(
File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 60, in dumps
result = cloudpickle.dumps(x, **dump_kwargs)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
cp.dump(obj)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
return Pickler.dump(self, obj)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
torch.save(cuda_attrs, f)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed
```

Output on the PR:

```
{'tcp://127.0.0.1:39677': <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=MyModule(
(dense0): Linear(in_features=20, out_features=10, bias=True)
),
), 'tcp://127.0.0.1:43891': <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=MyModule(
(dense0): Linear(in_features=20, out_features=10, bias=True)
),
)}
```
Unit Test Results: 18 files ±0, 18 suites ±0, 9h 41m 7s ⏱️ +2m 18s. For more details on these failures and errors, see this check. Results for commit 123ca52. ± Comparison against base commit 3902429. ♻️ This comment has been updated with latest results.
Thanks Vibhu! 😄
Have some suggestions below to support zero-copy pickling (assuming skorch uses objects that can take advantage of that).
Also have some questions about `module` and the special handling there.
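As an aside, a minimal illustration of the zero-copy idea (pickle protocol 5 out-of-band buffers) is sketched below — this is not code from this PR, and it uses a NumPy array only as an example of an object that exposes its memory:

```python
import pickle

import numpy as np

x = np.arange(10)

# With protocol 5, the pickler can hand large buffers to a callback instead of
# copying them into the pickle byte stream; Dask can then ship those buffers
# as separate frames.
buffers = []
header = pickle.dumps(x, protocol=5, buffer_callback=buffers.append)

# Loading reattaches the out-of-band buffers.
roundtripped = pickle.loads(header, buffers=buffers)
assert (roundtripped == x).all()
```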
```python
has_module = hasattr(x, "module_")
headers = {"has_module": has_module}
if has_module:
    module = x.__dict__.pop("module_")
```
Curious why `module_` can't be pickled on its own. Is there any more info on the issues encountered by leaving this?
Also any downside to (temporarily) modifying a user-provided object here?
> Curious why `module_` can't be pickled on its own. Is there any more info on the issues encountered by leaving this?

The module is a class defined interactively on the client, so its namespace is often `__main__`, e.g. `__main__.MyModule`.
Pickle has problems with interactively defined classes when they are set as attributes of another object, because it tries to look the class up by name in that namespace. See the trace above for an example.
By pickling it on its own we are able to serialize it successfully.

> Also any downside to (temporarily) modifying a user-provided object here?

The only side effect I can think of is if the class is redefined in the worker's namespace, causing undefined behavior while deserializing on the worker. I doubt that will really happen in real workflows.
FWIW, I have added a test to verify that at least the class is the same after deserialization.
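To make the pop-and-restore idea concrete, here is a rough sketch of how it could be wired into Dask's serialization registry — a hedged illustration based on the snippet above, not the PR's exact code, and the function names are made up:

```python
from distributed.protocol import dask_serialize, dask_deserialize
from distributed.protocol import pickle  # Dask's pickle wrapper (falls back to cloudpickle)
from skorch import NeuralNet


@dask_serialize.register(NeuralNet)
def serialize_skorch(x):
    # Pop the interactively defined module off the net and pickle it separately,
    # so the net itself no longer drags ``__main__.MyModule`` through skorch's
    # ``__getstate__`` / ``torch.save`` path.
    has_module = hasattr(x, "module_")
    header = {"has_module": has_module}
    if has_module:
        module = x.__dict__.pop("module_")
        try:
            frames = [pickle.dumps(x), pickle.dumps(module)]
        finally:
            x.module_ = module  # restore the user-provided object
    else:
        frames = [pickle.dumps(x)]
    return header, frames


@dask_deserialize.register(NeuralNet)
def deserialize_skorch(header, frames):
    net = pickle.loads(frames[0])
    if header["has_module"]:
        net.module_ = pickle.loads(frames[1])
    return net
```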
Is this just an issue with `pickle`? Does `cloudpickle` run into this issue or does it work ok?
Thanks a lot for the comprehensive review. It helped me learn a lot about zero-copy pickling and protocol 5. Very informative and useful. Sorry for the delay in updating the PR; I got sidetracked by other stuff. I think it's ready for another review. I have tried my best to address your reviews. Feel free to raise any more changes/points of clarification.
```python
deserialized_net = list(client.run(test_serialize_skorch, net).values())[0]
assert isinstance(deserialized_net.module_, MyModule)
```
This was the best test I could come up with for testing on the worker. Please let me know if there is a better way to check serialization in a different process.
I think we would benefit from a roundtrip serialization test like some of the others in this directory (without a cluster) to make sure that is working as expected. I know that doesn't show the error per se, but it will help catch other errors in the future.
In terms of testing on a worker, I would take a look at some of the other tests. Maybe like this one? Then adapt that to your use case. We shouldn't need to do the serialization manually ourselves, but instead rely on Dask to do that for us and merely check that things work as expected.
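A cluster-free roundtrip check might look something like this sketch (the test name is made up, and it reuses `MyModule` / `NeuralNetClassifier` from the minimal example above):

```python
from distributed.protocol import deserialize, serialize


def test_serialize_skorch_roundtrip():
    # Roundtrip through Dask's serialization machinery in a single process.
    net = NeuralNetClassifier(MyModule, max_epochs=10).initialize()
    header, frames = serialize(net)
    restored = deserialize(header, frames)
    assert isinstance(restored.module_, MyModule)
```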
Thanks Vibhu! 😄 Mainly just have questions on the existing threads above.
This PR tries to enable Skorch to run successfully with Dask-ML.
This PR follows the suggestion that @TomAugspurger made on dask/dask-ml#549 in this comment. That seems to fix issues in local testing with GridSearchCV.
We still have issues with HyperbandSearchCV. (Looking into that.) :-(
Example workflow that this PR enables:
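The sketch below illustrates the kind of workflow meant here — it assumes `dask_ml.model_selection.GridSearchCV` over the skorch net from the minimal example, with small synthetic data, and is not the exact workflow from the linked issues:

```python
import numpy as np
from dask.distributed import Client
from dask_ml.model_selection import GridSearchCV
from skorch import NeuralNetClassifier

# Reuses MyModule from the minimal example above (10 output units).
client = Client(processes=True, n_workers=2)

X = np.random.randn(200, 20).astype(np.float32)
y = np.random.randint(0, 10, size=200).astype(np.int64)  # labels match the 10 output units

net = NeuralNetClassifier(MyModule, max_epochs=5, train_split=False, verbose=0)
search = GridSearchCV(net, {"lr": [0.05, 0.1]}, cv=2)
search.fit(X, y)  # with this PR, the skorch net can be shipped to the workers
print(search.best_params_)
```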
I am still trying to come up with a cleaner test case to ensure the problem is truly fixed.
`pre-commit run --all-files`
Related issues:
- dask/dask-ml#549
- dask/dask-ml#696
- dask/dask-ml#664
- dask/dask-ml#892