Catch RuntimeError to avoid serialization fail when using pytorch #2619

mrocklin merged 5 commits into dask:master
Conversation
I'm curious, what caused the problem? Is it worth adding a small test?
This is related to dask#2619. Note that I had to use .detach_().numpy() in the test because, after de/serialization, `distributed` correctly returns a tensor with requires_grad=True. When requires_grad=True, .numpy() does not work and .detach_().numpy() has to be used instead.
When a torch tensor requires grad, one cannot use .numpy() on it directly. Proof: I undid the fix I submitted and added just that change back; running the test for the protocols then fails with the RuntimeError, while applying the change proposed in this PR makes everything work. Does it make sense?
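For reference, here is a minimal sketch of the failure mode described above (the tensor name and shape are illustrative, not taken from the PR):

import torch

t = torch.rand(4, requires_grad=True)

# .numpy() raises RuntimeError on a tensor that requires grad;
# this is the error the try/except in this PR catches.
try:
    t.numpy()
except RuntimeError as err:
    print(err)

# Detaching first returns a plain numpy array over the same data.
print(t.detach().numpy())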
cc @stsievert if you have time to review
try:
    header, frames = serialize(t.detach_().numpy())
except RuntimeError:
    header, frames = serialize(t.numpy())
Why not use an if-statement instead? It's known when the error will be thrown: when t.requires_grad is true.
I changed this block to use an if-statement.
try:
    header, frames = serialize(t.detach_().numpy())
I don't think we want to use detach_ here because it modifies t in place:
>>> import torch
>>> x = torch.rand(4, requires_grad=True)
>>> y = x.detach_().numpy()
>>> x.requires_grad
False

A test should be added to make sure that x.requires_grad stays the same before/after serialization.
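For contrast, a quick check in the same style (a sketch, not from the PR) showing that the non-mutating detach() leaves the original flag alone:

>>> import torch
>>> x = torch.rand(4, requires_grad=True)
>>> y = x.detach().numpy()  # detach() returns a new leaf; x is not modified
>>> x.requires_grad
True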
I originally placed a .detach() but after some discussions we decided to use .detach_() instead. The reason you gave, which made sense to me, was that .detach_() creates a leaf and does not break the graph. What we want is just to extract the data from the tensor object and have it de/serialized correctly, whether or not it requires gradients; this operation shouldn't perturb what the model expects from these tensors. Are we sure we want to go back to a plain .detach()?
This would be the commit to address what's been reviewed here:
diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py
index e6747e32..fa181a0d 100644
--- a/distributed/protocol/tests/test_torch.py
+++ b/distributed/protocol/tests/test_torch.py
@@ -21,8 +21,17 @@ def test_grad():
t.requires_grad = True
t2 = deserialize(*serialize(t))
- assert (t2.detach_().numpy() == x).all()
+
+ assert (t2.detach().numpy() == x).all()
assert (t2.grad.numpy() == 1).all()
+ assert (t2.requires_grad is True)
+
+ t.requires_grad = False
+ t3 = deserialize(*serialize(t))
+
+ assert (t3.detach().numpy() == x).all()
+ assert (t3.grad.numpy() == 1).all()
+ assert (t3.requires_grad is False)
def test_resnet():
diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py
index e4185468..a70e60b6 100644
--- a/distributed/protocol/torch.py
+++ b/distributed/protocol/torch.py
@@ -8,9 +8,9 @@ import numpy as np
def serialize_torch_Tensor(t):
requires_grad_ = t.requires_grad
- try:
+ if requires_grad_:
header, frames = serialize(t.detach_().numpy())
- except RuntimeError:
+ else:
header, frames = serialize(t.numpy())
if t.grad is not None:
Thoughts :)?
> .detach_() creates a leaf and does not break the graph.
I think this is true for detach() too. I think #2586 (comment) is concerned with correctly sending the requires_grad information.
Serialization should not implicitly modify the inputs. This would motivate the use of detach() over detach_().
I changed it back to detach().
x = np.arange(10)
t = torch.Tensor(x)
t.grad = torch.zeros_like(t) + 1
t.requires_grad = True
It'd be nice if this test could have both values of requires_grad. Maybe pytest.mark.parametrize?
I don't know how to do that :S I can investigate though.
Something like
@pytest.mark.parametrize("requires_grad", [True, False])
def test_grad(requires_grad):
    x = torch.zeros(..., requires_grad=requires_grad)
Very cool. I did not know about this. It is the first time I am using pytest. I have done something similar.
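Putting these suggestions together, a sketch of what the parametrized test could look like (assuming serialize/deserialize come from distributed.protocol, and folding in the np.allclose change from a later commit plus the requires_grad preservation check suggested further below):

import numpy as np
import pytest
import torch

from distributed.protocol import deserialize, serialize


@pytest.mark.parametrize("requires_grad", [True, False])
def test_grad(requires_grad):
    x = np.arange(10)
    t = torch.Tensor(x)  # float32 copy of x
    t.grad = torch.zeros_like(t) + 1
    t.requires_grad = requires_grad

    t2 = deserialize(*serialize(t))

    # The round trip must preserve data, grad, and the requires_grad flag.
    assert np.allclose(t2.detach().numpy(), x)
    assert np.allclose(t2.grad.numpy(), 1)
    assert t2.requires_grad == requires_grad

    # Serialization must not mutate its input.
    assert t.requires_grad == requires_grad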
t.requires_grad = False
t3 = deserialize(*serialize(t))
I would test that serialization and deserialization do not modify t.requires_grad. Adding this line below the serialization would test that:
assert t.requires_grad == requires_grad

- When a tensor requires grad, we have to use t.detach().numpy(); otherwise a plain .numpy() is used. This fixes the serialization failure present in the latest distributed version.
- Improved the test_grad() test as suggested by @stsievert.
- The whole PR is included in a single commit.
I updated the PR following all suggestions by @stsievert :). I also unified all changes into one single commit. Please let me know if any other changes are needed. Thanks.
Rerunning tests.
Checking in. What's the status here? @muammar, are you blocked on anything? It looks like @stsievert has made a few recommendations above.
- Verify that t.requires_grad is not modified by serialization.
- Use `np.allclose()` instead of `==`.
* upstream/master:
  Add Type Attribute to TaskState (dask#2657)
  Add waiting task count to progress title bar (dask#2663)
  DOC: Clean up reference to cluster object (dask#2664)
  Allow scheduler to politely close workers as part of shutdown (dask#2651)
  Check direct_to_workers before using get_worker in Client (dask#2656)
  Fixed comment regarding keeping existing level if less verbose (dask#2655)
  Add idle timeout to scheduler (dask#2652)
  Avoid deprecation warnings (dask#2653)
  Use an LRU cache for deserialized functions (dask#2623)
  Rename Worker._close to Worker.close (dask#2650)
  Add Comm closed bookkeeping (dask#2648)
  Explain LocalCluster behavior in Client docstring (dask#2647)
  Add last worker into KilledWorker exception to help debug (dask#2610)
  Set working worker class for dask-ssh (dask#2646)
  Add as_completed methods to docs (dask#2642)
  Add timeout to Client._reconnect (dask#2639)
  Limit test_spill_by_default memory, reenable it (dask#2633)
  Use proper address in worker -> nanny comms (dask#2640)
Thanks for checking in :). I am sorry I was not able to put this together earlier. The latest commit should (hopefully) comply with all of @stsievert's requests.
@muammar my concerns are resolved, especially the biggest one: now the tests make sure that serialization does not modify requires_grad.

Looks like these changes are complaining about linting issues and missing a run through Black (the other failures are unrelated).
Is there anything from my side I could do to help fix this?
Definitely with the linting issue. Dask uses black and flake8 for style, so I'd run

$ pip install black flake8
$ cd /path/to/distributed
$ black .
$ flake8 .

The docs at https://docs.dask.org/en/latest/develop.html#style could be improved.
Oook! I did not understand that you were referring to code style. I can do that, and probably submit a different PR for that purpose (?).
Thanks!
I've put in #2680 for this.
I have done it now. By the way, I had no idea black existed. Very interesting; I will use it on my projects.
The Windows failure is unrelated (handled in #2684). Thanks for the contribution @muammar, and thanks for the review @stsievert. Merging.
* upstream/master:
  Add WeakSet _instances attributes to all classes (dask#2673)
  Cap worker's memory limit by the hard limit of the maximum resident memory (dask#2665)
  Switch from (ip, port) to address in tests (dask#2684)
  Catch RuntimeError to avoid serialization fail when using pytorch (dask#2619)
  Add CONTRIBUTING.md (dask#2680)
  Consolidate logic around services (dask#2679)
  Fix uri_from_host_port import in dask-mpi (dask#2683)
  Move dashboard_address logic into Scheduler/Worker (dask#2678)
  Fix pytest.config deprecation warning (dask#2677)
  Use config accessor method for "scheduler-address" (dask#2676)
  Add memory and disk aliases to Worker.data (dask#2670)
  Move interface/host/port handling from CLI to classes (dask#2667)
  Remove AioClient (dask#2668)
  Add release procedure doc (dask#2672)
  bump version to 1.28.0
After updating one of my systems, I realized that there was a problem in distributed/protocol/torch.py that made serialization fail. In this PR, the problem is fixed by adding a try/except statement.
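To summarize the final approach, here is a self-contained sketch of the serializer pair this PR converges on. The branching and the grad handling follow the diff above; the header layout and the deserializer details are illustrative assumptions, not distributed's exact format:

import numpy as np
import torch

from distributed.protocol import deserialize, serialize


def serialize_torch_Tensor(t):
    # Branch on requires_grad instead of catching RuntimeError:
    # it is known exactly when .numpy() would fail.
    requires_grad_ = t.requires_grad
    if requires_grad_:
        # detach() returns a new leaf sharing storage; t is not mutated.
        sub_header, frames = serialize(t.detach().numpy())
    else:
        sub_header, frames = serialize(t.numpy())
    # Illustrative header layout (an assumption, not the library's format).
    header = {"sub-header": sub_header, "requires_grad": requires_grad_, "grad": None}
    if t.grad is not None:
        grad_header, grad_frames = serialize(t.grad.numpy())
        header["grad"] = {"header": grad_header, "start": len(frames)}
        frames = frames + grad_frames
    return header, frames


def deserialize_torch_Tensor(header, frames):
    if header["grad"] is not None:
        i = header["grad"]["start"]
        x = deserialize(header["sub-header"], frames[:i])
        grad = deserialize(header["grad"]["header"], frames[i:])
    else:
        x = deserialize(header["sub-header"], frames)
        grad = None
    t = torch.tensor(np.asarray(x))
    if grad is not None:
        t.grad = torch.tensor(np.asarray(grad))
    # Restore the flag last, so the data/grad setup above uses plain tensors.
    t.requires_grad = header["requires_grad"]
    return t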