
[WIP] Enable Skorch+Dask-ML #5748

Draft · wants to merge 3 commits into base: main

Conversation

@VibhuJawa commented Feb 2, 2022

This PR tries to enable Skorch to run successfully with Dask-ML.

This PR follows the suggestion that @TomAugspurger made on dask/dask-ml#549 in this comment.

That seems to fix the issues in local testing with GridSearchCV.

We still have issues with HyperBandCV (looking into that). :-(

Example workflow that this PR enables (sketched below):
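
A hedged sketch of such a workflow (this is illustrative rather than the linked example; the toy data, layer sizes, and parameter grid are made up):

import numpy as np
from dask_ml.model_selection import GridSearchCV
from distributed import Client
from skorch import NeuralNetClassifier
from torch import nn


class MyModule(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        self.dense0 = nn.Linear(20, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        return self.softmax(self.output(self.dense0(X)))


if __name__ == "__main__":
    client = Client(processes=True, n_workers=2)

    # Toy data: any float32 features and int64 labels will do.
    X = np.random.randn(200, 20).astype("float32")
    y = np.random.randint(0, 2, 200).astype("int64")

    net = NeuralNetClassifier(MyModule, max_epochs=5, lr=0.1, verbose=0)

    # Dask-ML ships the estimator to the workers, which requires serializing
    # it. That is the step that fails on main and works with this PR.
    search = GridSearchCV(net, {"lr": [0.05, 0.1], "module__num_units": [10, 20]})
    search.fit(X, y)
    print(search.best_params_)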

I am still trying to come up with a cleaner test case to ensure the problem is truly fixed.

  • Closes #xxxx
  • Tests added
  • Passes pre-commit run --all-files

Related issues:
dask/dask-ml#549
dask/dask-ml#696
dask/dask-ml#664
dask/dask-ml#892

@GPUtester (Collaborator)

Can one of the admins verify this patch?

@quasiben (Member) commented Feb 2, 2022

add to allowlist

@VibhuJawa (Author) commented Feb 2, 2022

Minimal Example of the error that this PR fixes:

from distributed import Client
from distributed.protocol import deserialize, serialize
from skorch import NeuralNetClassifier
from torch import nn

# Worker processes are separate from the client process, so the estimator
# must survive Dask's (de)serialization to get there.
client = Client(processes=True, n_workers=2)


class MyModule(nn.Module):
    """A module class defined interactively, i.e. living in __main__."""

    def __init__(self, num_units=10):
        super().__init__()
        self.dense0 = nn.Linear(20, num_units)

    def forward(self, X, **kwargs):
        return self.dense0(X)


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    iterator_train__shuffle=True,
)


def test_serialize_skorch(net):
    # Round-trip the initialized estimator through Dask's serialization
    # machinery on each worker.
    net = net.initialize()
    return deserialize(*serialize(net))


output = client.run(test_serialize_skorch, net)
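
Note: client.run executes the function once on every worker and returns a dict keyed by worker address, which is why the successful output further below contains one deserialized estimator per worker.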

Traceback on main (without this PR):

distributed.worker - WARNING - Run Failed
Function: test_serialize_skorch
args:     (<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=MyModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
  ),
))
kwargs:   {}
Traceback (most recent call last):
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/worker.py", line 4648, in run
    result = function(*args, **kwargs)
  File "/tmp/ipykernel_9764/775506166.py", line 24, in test_serialize_skorch
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 417, in deserialize
    return loads(header, frames)
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 180, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type NeuralNetClassifier.
Traceback (most recent call last):
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 49, in dumps
    result = pickle.dumps(x, **dump_kwargs)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
    torch.save(cuda_attrs, f)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 340, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 63, in pickle_dumps
    frames[0] = pickle.dumps(
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 60, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
    return Pickler.dump(self, obj)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
    torch.save(cuda_attrs, f)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed

(The second worker emits the same warning and traceback.)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [1], in <module>
     23     net = net.initialize()
     24     return deserialize(*serialize(net))
---> 26 output = client.run(test_serialize_skorch, net)

File ~/skorch_dask_cuda/distributed/distributed/client.py:2750, in Client.run(self, function, workers, wait, nanny, on_error, *args, **kwargs)
   2667 def run(
   2668     self,
   2669     function,
   (...)
   2675     **kwargs,
   2676 ):
   2677     """
   2678     Run a function on all workers outside of task scheduling system
   2679 
   (...)
   2748     >>> c.run(print_state, wait=False)  # doctest: +SKIP
   2749     """
-> 2750     return self.sync(
   2751         self._run,
   2752         function,
   2753         *args,
   2754         workers=workers,
   2755         wait=wait,
   2756         nanny=nanny,
   2757         on_error=on_error,
   2758         **kwargs,
   2759     )

File ~/skorch_dask_cuda/distributed/distributed/utils.py:309, in SyncMethodMixin.sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    307     return future
    308 else:
--> 309     return sync(
    310         self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    311     )

File ~/skorch_dask_cuda/distributed/distributed/utils.py:363, in sync(loop, func, callback_timeout, *args, **kwargs)
    361 if error[0]:
    362     typ, exc, tb = error[0]
--> 363     raise exc.with_traceback(tb)
    364 else:
    365     return result[0]

File ~/skorch_dask_cuda/distributed/distributed/utils.py:348, in sync.<locals>.f()
    346     if callback_timeout is not None:
    347         future = asyncio.wait_for(future, callback_timeout)
--> 348     result[0] = yield future
    349 except Exception:
    350     error[0] = sys.exc_info()

File /datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/tornado/gen.py:762, in Runner.run(self)
    759 exc_info = None
    761 try:
--> 762     value = future.result()
    763 except Exception:
    764     exc_info = sys.exc_info()

File ~/skorch_dask_cuda/distributed/distributed/client.py:2655, in Client._run(self, function, nanny, workers, wait, on_error, *args, **kwargs)
   2652     continue
   2654 if on_error == "raise":
-> 2655     raise exc
   2656 elif on_error == "return":
   2657     results[key] = exc

Input In [1], in test_serialize_skorch()
     22 def test_serialize_skorch(net):
     23     net = net.initialize()
---> 24     return deserialize(*serialize(net))

File ~/skorch_dask_cuda/distributed/distributed/protocol/serialize.py:417, in deserialize()
    412     raise TypeError(
    413         "Data serialized with %s but only able to deserialize "
    414         "data with %s" % (name, str(list(deserializers)))
    415     )
    416 dumps, loads, wants_context = families[name]
--> 417 return loads(header, frames)

File ~/skorch_dask_cuda/distributed/distributed/protocol/serialize.py:180, in serialization_error_loads()
    178 def serialization_error_loads(header, frames):
    179     msg = "\n".join([ensure_bytes(frame).decode("utf8") for frame in frames])
--> 180     raise TypeError(msg)

TypeError: Could not serialize object of type NeuralNetClassifier.
Traceback (most recent call last):
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 49, in dumps
    result = pickle.dumps(x, **dump_kwargs)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
    torch.save(cuda_attrs, f)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 340, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/serialize.py", line 63, in pickle_dumps
    frames[0] = pickle.dumps(
  File "/home/nfs/vjawa/skorch_dask_cuda/distributed/distributed/protocol/pickle.py", line 60, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
    return Pickler.dump(self, obj)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/skorch/net.py", line 2048, in __getstate__
    torch.save(cuda_attrs, f)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/datasets/vjawa/miniconda3/envs/cuml-torch-skorch/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class '__main__.MyModule'>: attribute lookup MyModule on __main__ failed

Output with this PR:

{'tcp://127.0.0.1:39677': <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=MyModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
  ),
), 'tcp://127.0.0.1:43891': <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=MyModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
  ),
)}

@github-actions bot (Contributor) commented Feb 2, 2022

Unit Test Results

18 files ±0 · 18 suites ±0 · 9h 41m 7s ⏱️ +2m 18s
2 596 tests +32: 2 510 passed ✔️ +29, 82 skipped 💤 +2, 3 failed ±0, 1 errored 🔥 +1
23 239 runs +276: 21 724 passed ✔️ +467, 1 510 skipped 💤 -193, 4 failed +1, 1 errored 🔥 +1

For more details on these failures and errors, see this check.

Results for commit 123ca52, compared against base commit 3902429.

♻️ This comment has been updated with latest results.

@jakirkham (Member) left a comment

Thanks Vibhu! 😄

Have some suggestions below to support zero-copy pickling (assuming skorch uses objects that can take advantage of that).

Also have some questions about module and the special handling there.
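
(For context, a minimal and generic illustration of the zero-copy pickling referred to here, i.e. pickle protocol 5 with out-of-band buffers, which is the mechanism Dask's pickle-based serialization builds on. The NumPy array is just a stand-in for any buffer-exposing object.)

import pickle

import numpy as np

data = np.arange(1_000_000, dtype="float64")

# With protocol 5, large buffers are handed to buffer_callback instead of
# being copied into the pickle stream, so they can travel as separate frames.
buffers = []
payload = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)

restored = pickle.loads(payload, buffers=buffers)
assert (restored == data).all()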

distributed/protocol/skorch.py: 7 outdated review threads (all resolved)
has_module = hasattr(x, "module_")
headers = {"has_module": has_module}
if has_module:
module = x.__dict__.pop("module_")
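
(For context, the excerpt above appears to come from a serializer registration along the following lines. This is a hedged sketch only; the actual distributed/protocol/skorch.py under review may differ in detail.)

# Pop module_ off the estimator, serialize it separately so that cloudpickle
# can carry an interactively defined class by value, and reattach it on the
# receiving side.
from distributed.protocol import dask_deserialize, dask_serialize
from distributed.protocol.pickle import dumps, loads
from skorch import NeuralNet


@dask_serialize.register(NeuralNet)
def serialize_skorch(x):
    has_module = hasattr(x, "module_")
    headers = {"has_module": has_module}
    if has_module:
        module = x.__dict__.pop("module_")
        try:
            frames = [dumps(module), dumps(x)]
        finally:
            x.module_ = module  # restore the user-provided object
    else:
        frames = [dumps(x)]
    return headers, frames


@dask_deserialize.register(NeuralNet)
def deserialize_skorch(header, frames):
    if header["has_module"]:
        net = loads(frames[1])
        net.module_ = loads(frames[0])
    else:
        net = loads(frames[0])
    return net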
Member:

Curious why module_ can't be pickled on its own. Is there any more info on the issues encountered by leaving this?

Also any downside to (temporarily) modifying a user-provided object here?

Author:

> Curious why module_ can't be pickled on its own. Is there any more info on the issues encountered by leaving this?

So module_ is an instance of a class defined interactively on the client, so that class's namespace is often __main__, e.g. __main__.MyModule.

Pickle has problems with interactively defined classes when they are set as attributes of another object, because it tries to look up the class in that namespace (see the trace above for an example).
By pickling module_ on its own we are able to serialize successfully.

> Also any downside to (temporarily) modifying a user-provided object here?

The only side effect I can think of is the class being redefined in the worker's namespace, causing undefined behavior while deserializing on the worker. I doubt that will really happen in real workflows.

FWIW, I have added a test to verify that at least the class is the same after deserialization.
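
(To make the by-reference vs. by-value distinction concrete, here is a generic illustration that is independent of skorch. The standard pickler stores classes by reference, such as __main__.MyModule, while cloudpickle serializes classes defined in __main__ by value. The trace above fails inside skorch's __getstate__, which routes attributes through torch.save, and torch.save defaults to the standard pickle module, so the outer cloudpickle pass never gets to handle the class.)

import pickle

import cloudpickle


class InteractiveClass:  # stand-in for a class defined in a notebook cell
    pass


# Standard pickle stores only a reference ("__main__.InteractiveClass");
# loading in another process works only if that process can look the class up.
by_reference = pickle.dumps(InteractiveClass)

# cloudpickle serializes __main__ classes by value: the class definition
# itself travels with the bytes, which is why pickling module_ on its own
# through Dask's cloudpickle fallback succeeds.
by_value = cloudpickle.dumps(InteractiveClass)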

Member:

Is this just an issue with pickle? Does cloudpickle run into this issue or does it work ok?

@VibhuJawa (Author)

@jakirkham,

Thanks a lot for the comprehensive review. It helped me learn a lot about zero-copy pickling and protocol 5. Very informative and useful.

Sorry for the delay in updating the PR; I got sidetracked by other stuff. I think it's ready for another review. I have tried my best to address your comments. Feel free to raise any more changes or points of clarification.

Comment on lines +39 to +40
deserialized_net = list(client.run(test_serialize_skorch, net).values())[0]
assert isinstance(deserialized_net.module_, MyModule)
@VibhuJawa (Author) commented Feb 9, 2022

This was the best test I could come up with for testing on the worker. Please let me know if there is a better way to check serialization in a different process.

Member:

I think we would benefit from a roundtrip serialization test like some of the others in this directory (without a cluster) to make sure that is working as expected. I know that doesn't show the error per se, but it will help catch other errors in the future.

In terms of testing on a worker, I would take a look at some of the other tests, maybe like this one, and adapt it to your use case. We shouldn't need to do the serialization manually ourselves, but instead rely on Dask to do that for us and merely check that things work as expected.
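
(A hedged sketch of such a roundtrip test; the test name and assertions are illustrative, and the real tests under distributed/protocol/tests may be structured differently.)

from distributed.protocol import deserialize, serialize
from skorch import NeuralNetClassifier
from torch import nn


class MyModule(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        self.dense0 = nn.Linear(20, num_units)

    def forward(self, X, **kwargs):
        return self.dense0(X)


def test_serialize_skorch_roundtrip():
    # In-process round trip through Dask's serialize/deserialize, no cluster.
    net = NeuralNetClassifier(MyModule, max_epochs=10).initialize()
    header, frames = serialize(net)
    restored = deserialize(header, frames)
    assert isinstance(restored.module_, MyModule)
    assert restored.max_epochs == net.max_epochs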

@jakirkham (Member)

Thanks Vibhu! 😄

Mainly I just have questions on the existing threads above.
