
Catch RuntimeError to avoid serialization fail when using pytorch #2619

Merged
mrocklin merged 5 commits into dask:master from muammar:master
May 10, 2019

Conversation

@muammar (Contributor) commented Apr 17, 2019

After updating one of my systems, I realized there was a problem in distributed/protocol/torch.py that made serialization fail. This PR fixes it by adding a try/except statement.

@mrocklin (Member)

I'm curious, what caused the problem? Is it worth adding a small test?

muammar added a commit to muammar/distributed that referenced this pull request Apr 18, 2019
This is related to dask#2619. Note that I had to use .detach_().numpy() in
the test because after de/serialization `distributed` correctly returns
a tensor with requires_grad=True. When requires_grad=True .numpy() does
not work and .detach_().numpy() has to be used instead.
@muammar (Contributor, Author) commented Apr 18, 2019

I'm curious, what caused the problem? Is it worth adding a small test?

When a torch tensor requires grad, one cannot use .numpy() to convert it to a numpy array; one has to use .detach_() first. In our example, the gradient data is set to some arbitrary values, but the tensor does not have its .requires_grad attribute set to True.
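The failure mode is easy to reproduce without dask or torch. Below is a minimal pure-Python sketch: `FakeTensor` is a hypothetical stand-in that raises the same RuntimeError torch does for tensors requiring grad, and `to_numpy` branches the way the final version of this PR does.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, for illustration only."""

    def __init__(self, data, requires_grad=False):
        self.data = list(data)
        self.requires_grad = requires_grad

    def numpy(self):
        # torch raises exactly this error for tensors that require grad
        if self.requires_grad:
            raise RuntimeError(
                "Can't call numpy() on Variable that requires grad. "
                "Use var.detach().numpy() instead."
            )
        return self.data

    def detach(self):
        # Out-of-place: returns a copy that no longer tracks gradients
        return FakeTensor(self.data, requires_grad=False)


def to_numpy(t):
    # Branch on requires_grad, as the final version of this PR does
    return t.detach().numpy() if t.requires_grad else t.numpy()
```

With this branch in place, both kinds of tensors convert cleanly, which is what the serializer needs.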

Proof:

I undid the fix I submitted and added only this change:

diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py
index 6cc8bb20..1e2190e8 100644
--- a/distributed/protocol/tests/test_torch.py
+++ b/distributed/protocol/tests/test_torch.py
@@ -17,6 +17,7 @@ def test_tensor():
 def test_grad():
     x = np.arange(10)
     t = torch.Tensor(x)
+    t.requires_grad = True
     t.grad = torch.zeros_like(t) + 1
 
     t2 = deserialize(*serialize(t))

Running the test for the protocols fails with:

distributed/protocol/tests/test_torch.py::test_grad FAILED                                                                                                                                                  [ 98%]
distributed/protocol/tests/test_torch.py::test_resnet PASSED                                                                                                                                                [ 99%]
distributed/protocol/tests/test_torch.py::test_deserialize_grad PASSED                                                                                                                                      [100%]

==================================================================================================== FAILURES =====================================================================================================
____________________________________________________________________________________________________ test_grad ____________________________________________________________________________________________________

    def test_grad():
        x = np.arange(10)
        t = torch.Tensor(x)
        t.requires_grad = True
        t.grad = torch.zeros_like(t) + 1
    
        t2 = deserialize(*serialize(t))
>       assert (t2.numpy() == x).all()
E       RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

distributed/protocol/tests/test_torch.py:24: RuntimeError

Applying the change proposed in this PR makes everything work:

distributed/protocol/tests/test_sklearn.py::test_basic PASSED                                                                                                                                               [ 97%]
distributed/protocol/tests/test_torch.py::test_tensor PASSED                                                                                                                                                [ 97%]
distributed/protocol/tests/test_torch.py::test_grad PASSED                                                                                                                                                  [ 98%]
distributed/protocol/tests/test_torch.py::test_resnet PASSED                                                                                                                                                [ 99%]
distributed/protocol/tests/test_torch.py::test_deserialize_grad PASSED     

Does it make sense?

@mrocklin (Member)

cc @stsievert if you have time to review

try:
    header, frames = serialize(t.detach_().numpy())
except RuntimeError:
    header, frames = serialize(t.numpy())
Member

Why not use an if-statement instead? It's known when the error will be thrown: when t.requires_grad is True.

Contributor Author

I changed this block to use an if-statement.

Comment thread distributed/protocol/torch.py (outdated)

try:
    header, frames = serialize(t.detach_().numpy())
Member

I don't think we want to use detach_ here because it modifies t in place:

>>> import torch
>>> x = torch.rand(4, requires_grad=True)
>>> y = x.detach_().numpy()
>>> x.requires_grad
False

A test should be added to make sure that x.requires_grad stays the same before/after serialization.
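The in-place hazard can also be shown without torch installed. `FakeGradTensor` below is a hypothetical stand-in whose detach()/detach_() mimic torch's out-of-place vs in-place semantics, so the side effect the review flags is visible directly.

```python
class FakeGradTensor:
    """Hypothetical stand-in mimicking detach() vs detach_() semantics."""

    def __init__(self, requires_grad=True):
        self.requires_grad = requires_grad

    def detach(self):
        # Out-of-place: new object, the original is untouched
        return FakeGradTensor(requires_grad=False)

    def detach_(self):
        # In-place: mutates self -- the hazard flagged in this review
        self.requires_grad = False
        return self


x = FakeGradTensor(requires_grad=True)
y = x.detach()
assert x.requires_grad is True   # detach() preserved the input's flag
x.detach_()
assert x.requires_grad is False  # detach_() clobbered it
```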

Contributor Author

I originally used .detach(), but after some discussion we decided to use .detach_() instead. The reasoning, which made sense to me, was that .detach_() creates a leaf and does not break the graph. What we want is just to extract the data from the tensor object and have it de/serialized correctly whether or not it requires gradients; this operation shouldn't perturb what the model expects from these tensors. Are we sure we have to go back to a plain .detach()?

This would be the commit to address what's been reviewed here:

diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py
index e6747e32..fa181a0d 100644
--- a/distributed/protocol/tests/test_torch.py
+++ b/distributed/protocol/tests/test_torch.py
@@ -21,8 +21,17 @@ def test_grad():
     t.requires_grad = True
 
     t2 = deserialize(*serialize(t))
-    assert (t2.detach_().numpy() == x).all()
+
+    assert (t2.detach().numpy() == x).all()
     assert (t2.grad.numpy() == 1).all()
+    assert (t2.requires_grad is True)
+
+    t.requires_grad = False
+    t3 = deserialize(*serialize(t))
+
+    assert (t3.detach().numpy() == x).all()
+    assert (t3.grad.numpy() == 1).all()
+    assert (t3.requires_grad is False)
 
 
 def test_resnet():
diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py
index e4185468..a70e60b6 100644
--- a/distributed/protocol/torch.py
+++ b/distributed/protocol/torch.py
@@ -8,9 +8,9 @@ import numpy as np
 def serialize_torch_Tensor(t):
     requires_grad_ = t.requires_grad
 
-    try:
+    if requires_grad_:
         header, frames = serialize(t.detach_().numpy())
-    except RuntimeError:
+    else:
         header, frames = serialize(t.numpy())
 
     if t.grad is not None:

Thoughts :)?

@stsievert (Member) commented Apr 21, 2019

.detach_() creates a leaf and does not break the graph.

I think this is true for detach() too. I think #2586 (comment) is concerned with correctly sending the requires_grad information.

Serialization should not implicitly modify the inputs. This would motivate the use of detach() over detach_().

Contributor Author

I changed it back to detach().

x = np.arange(10)
t = torch.Tensor(x)
t.grad = torch.zeros_like(t) + 1
t.requires_grad = True
Member

It'd be nice if this test could have both values of requires_grad. Maybe pytest.mark.parametrize?

Contributor Author

I don't know how to do that :S I can investigate though.

Member

Something like

@pytest.mark.parametrize("requires_grad", [True, False])
def test_grad(requires_grad):
    x = torch.zeros(..., requires_grad=requires_grad)
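For completeness, here is a self-contained sketch of the parametrize pattern: pytest runs the body once per listed value. The serialize/deserialize round-trip is replaced by a hypothetical dict copy so the sketch runs without torch or distributed; a real test would use distributed.protocol.serialize/deserialize.

```python
import pytest


@pytest.mark.parametrize("requires_grad", [True, False])
def test_grad_roundtrip(requires_grad):
    # Hypothetical stand-in for a serialize/deserialize round-trip
    record = {"data": [0.0] * 10, "requires_grad": requires_grad}
    restored = dict(record)  # pretend round-trip

    # The flag must survive the round-trip unchanged
    assert restored["requires_grad"] == requires_grad
```

Running `pytest` on this file would report two tests, `test_grad_roundtrip[True]` and `test_grad_roundtrip[False]`.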

Contributor Author

Very cool. I did not know about this. It is the first time I am using pytest. I have done something similar.


t.requires_grad = False
t3 = deserialize(*serialize(t))

Member

I would test that serialization and de-serialization do not modify t.requires_grad. Adding this line below the serialization would test that:

assert t.requires_grad == requires_grad

Contributor Author

Done, too.

- When a tensor has requires_grad set we have to use t.detach().numpy(); otherwise .numpy() is used. This fixes the serialization failure present in the latest distributed version.
- Improved the test_grad() test as suggested by @stsievert.
- The whole PR is squashed into a single commit.
@muammar (Contributor, Author) commented Apr 23, 2019

I updated the PR following all suggestions by @stsievert :). I also squashed all changes into a single commit. Please let me know if any other changes are needed. Thanks.

@stsievert (Member) left a comment

Thanks for the changes @muammar. I've got a couple nits. Past those it looks good.

Comment thread distributed/protocol/tests/test_torch.py (outdated)
* upstream/master:
  Fix deserialization of bytes chunks larger than 64MB (dask#2637)
  bump version to 1.27.1
  Updated logging module doc links from docs.python.org/2 to docs.python.org/3. (dask#2635)
  Adaptive: recommend close workers when any are idle (dask#2330)
@quasiben (Member)

rerunning tests

@mrocklin (Member)

mrocklin commented May 4, 2019

Checking in. What's the status here? @muammar are you blocked on anything? It looks like @stsievert has made a few recommendations above.

muammar added 2 commits May 7, 2019 18:43
- Verify that t.requires_grad is not modified by serialization.
- Use `np.allclose()` instead of `==`.
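The switch to `np.allclose()` in the second commit is worth a one-line illustration: `==` on float arrays demands bit-exact equality, while `np.allclose()` tolerates round-off, which makes assertions robust after a round-trip through serialization. The values below are illustrative only.

```python
import numpy as np

# Bit-exact comparison can fail on harmless floating-point round-off:
a = np.array([0.1 + 0.2, 1.0])
b = np.array([0.3, 1.0])

exact = bool((a == b).all())     # False: 0.1 + 0.2 != 0.3 exactly
close = bool(np.allclose(a, b))  # True: equal within default tolerances
```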
* upstream/master:
  Add Type Attribute to TaskState (dask#2657)
  Add waiting task count to progress title bar (dask#2663)
  DOC: Clean up reference to cluster object (dask#2664)
  Allow scheduler to politely close workers as part of shutdown (dask#2651)
  Check direct_to_workers before using get_worker in Client (dask#2656)
  Fixed comment regarding keeping existing level if less verbose (dask#2655)
  Add idle timeout to scheduler (dask#2652)
  Avoid deprecation warnings (dask#2653)
  Use an LRU cache for deserialized functions (dask#2623)
  Rename Worker._close to Worker.close (dask#2650)
  Add Comm closed bookkeeping (dask#2648)
  Explain LocalCluster behavior in Client docstring (dask#2647)
  Add last worker into KilledWorker exception to help debug  (dask#2610)
  Set working worker class for dask-ssh (dask#2646)
  Add as_completed methods to docs (dask#2642)
  Add timeout to Client._reconnect (dask#2639)
  Limit test_spill_by_default memory, reenable it (dask#2633)
  Use proper address in worker -> nanny comms (dask#2640)
@muammar (Contributor, Author) commented May 8, 2019

Checking in. What's the status here? @muammar are you blocked on anything? It looks like @stsievert has made a few recommendations above.

Thanks for checking in :). I am sorry I was not able to put this together sooner. The latest commit should (hopefully) comply with all requests by @stsievert.

@stsievert (Member)

@muammar my concerns are resolved, especially the biggest one: now, the tests make sure the requires_grad flag of the input array is unmodified.

Looks like CI is complaining about linting issues and a missing run through Black (the other failures are unrelated, in test_client.py::test_profile_bokeh and test_local.py::test_ipywidgets).

@muammar (Contributor, Author) commented May 9, 2019

@muammar my concerns are resolved, especially the biggest one: now, the tests make sure the requires_grad flag of the input array is unmodified.

Looks like CI is complaining about linting issues and a missing run through Black (the other failures are unrelated, in test_client.py::test_profile_bokeh and test_local.py::test_ipywidgets).

Is there anything from my side I could do to help fix this?

@stsievert (Member)

Is there anything from my side I could do to help fix this?

Definitely with the linting issue. Dask uses black and flake8 for style, so I'd run

$ pip install black flake8
$ cd /path/to/distributed
$ black .
$ flake8 .

The docs at https://docs.dask.org/en/latest/develop.html#style could be improved.

@muammar (Contributor, Author) commented May 9, 2019

Is there anything from my side I could do to help fix this?

Definitely with the linting issue. Dask uses black and flake8 for style, so I'd run

$ pip install black flake8
$ cd /path/to/distributed
$ black .
$ flake8 .

The docs at https://docs.dask.org/en/latest/develop.html#style could be improved.

Oook! I did not understand you were referring to code style. I can do that and probably submit a different PR for that purpose (?).

@stsievert (Member)

Oook! I did not understand you were referring to code style.

Thanks!

I can do that and probably submit a different PR for that purpose (?).

I've put in #2680 for this.

@muammar (Contributor, Author) commented May 10, 2019

Oook! I did not understand you were referring to code style.

Thanks!

I can do that and probably submit a different PR for that purpose (?).

I've put in #2680 for this.

I have done it now. By the way, I had no idea of the existence of black. Very interesting. I will use it on my projects.

@mrocklin (Member)

The windows failure is unrelated (handling in #2684)

Thanks for the contribution @muammar and thanks for the review @stsievert .

Merging.

@mrocklin mrocklin merged commit a8fa4c1 into dask:master May 10, 2019
muammar added a commit to muammar/distributed that referenced this pull request Jul 18, 2019
* upstream/master:
  Add WeakSet _instances attributes to all classes (dask#2673)
  Cap worker's memory limit by the hard limit of the maximum resident memory (dask#2665)
  Switch from (ip, port) to address in tests (dask#2684)
  Catch RuntimeError to avoid serialization fail when using pytorch (dask#2619)
  Add CONTRIBUTING.md (dask#2680)
  Consolidate logic around services (dask#2679)
  Fix uri_from_host_port import in dask-mpi (dask#2683)
  Move dashboard_address logic into Scheduler/Worker (dask#2678)
  Fix pytest.config deprecation warning (dask#2677)
  Use config accessor method for "scheduler-address" (dask#2676)
  Add memory and disk aliases to Worker.data (dask#2670)
  Move interface/host/port handling from CLI to classes (dask#2667)
  Remove AioClient (dask#2668)
  Add release procedure doc (dask#2672)
  bump version to 1.28.0