From 52e5d0e55da675af38b5c3b95d6357b790c7030d Mon Sep 17 00:00:00 2001 From: Muammar El Khatib Date: Tue, 23 Apr 2019 10:47:44 -0700 Subject: [PATCH 1/3] Fix `Failed to Serialize` error with pytorch tensors. - When a tensor has requires_grad=True, we have to call t.detach().numpy(); otherwise plain t.numpy() is used. This fixes the `Failed to Serialize` problem present in the latest distributed version. - Improved test_grad() test as suggested by @stsievert. - The whole PR is included in a single commit. --- distributed/protocol/tests/test_torch.py | 10 ++++++---- distributed/protocol/torch.py | 7 ++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py index 6cc8bb20986..0f684ce7aaa 100644 --- a/distributed/protocol/tests/test_torch.py +++ b/distributed/protocol/tests/test_torch.py @@ -14,15 +14,17 @@ def test_tensor(): assert (x == t2.numpy()).all() -def test_grad(): +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_grad(requires_grad): x = np.arange(10) - t = torch.Tensor(x) + t = torch.tensor(x, dtype=torch.float, requires_grad=requires_grad) t.grad = torch.zeros_like(t) + 1 t2 = deserialize(*serialize(t)) - assert (t2.numpy() == x).all() - assert (t2.grad.numpy() == 1).all() + assert (t2.detach().numpy() == x).all() + assert (t2.grad.numpy() == 1).all() + assert (t2.requires_grad is requires_grad) def test_resnet(): torchvision = pytest.importorskip("torchvision") diff --git a/distributed/protocol/torch.py b/distributed/protocol/torch.py index e69be68b0c1..3b4c6d19c8d 100644 --- a/distributed/protocol/torch.py +++ b/distributed/protocol/torch.py @@ -7,7 +7,12 @@ @dask_serialize.register(torch.Tensor) def serialize_torch_Tensor(t): requires_grad_ = t.requires_grad - header, frames = serialize(t.detach_().numpy()) + + if requires_grad_: + header, frames = serialize(t.detach().numpy()) + else: + header, frames = serialize(t.numpy()) + if t.grad is not None: grad_header, 
grad_frames = serialize(t.grad.numpy()) header["grad"] = {"header": grad_header, "start": len(frames)} From 18963d91e4fab14f91e918c005308c3e81c7545c Mon Sep 17 00:00:00 2001 From: Muammar El Khatib Date: Tue, 7 May 2019 18:43:51 -0700 Subject: [PATCH 2/3] More improvements to test_torch - Verify that t.requires_grad is not modified by serialization. - Use `np.allclose()` instead of `==`. --- distributed/protocol/tests/test_torch.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py index 0f684ce7aaa..eceb94ce0e0 100644 --- a/distributed/protocol/tests/test_torch.py +++ b/distributed/protocol/tests/test_torch.py @@ -13,18 +13,22 @@ def test_tensor(): t2 = deserialize(header, frames) assert (x == t2.numpy()).all() - @pytest.mark.parametrize("requires_grad", [True, False]) def test_grad(requires_grad): x = np.arange(10) t = torch.tensor(x, dtype=torch.float, requires_grad=requires_grad) - t.grad = torch.zeros_like(t) + 1 + + if requires_grad: + t.grad = torch.zeros_like(t) + 1 t2 = deserialize(*serialize(t)) - assert (t2.detach().numpy() == x).all() - assert (t2.grad.numpy() == 1).all() assert (t2.requires_grad is requires_grad) + assert (t.requires_grad is requires_grad) + assert np.allclose(t2.detach().numpy(), x) + + if requires_grad: + assert np.allclose(t2.grad.numpy(), 1) def test_resnet(): torchvision = pytest.importorskip("torchvision") From fffb68b5f921861badf1563fb462985a1569f3f7 Mon Sep 17 00:00:00 2001 From: Muammar El Khatib Date: Fri, 10 May 2019 08:52:43 -0700 Subject: [PATCH 3/3] flake8 and black cleaning. 
--- distributed/protocol/tests/test_torch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/tests/test_torch.py b/distributed/protocol/tests/test_torch.py index eceb94ce0e0..efb5fa6610a 100644 --- a/distributed/protocol/tests/test_torch.py +++ b/distributed/protocol/tests/test_torch.py @@ -13,6 +13,7 @@ def test_tensor(): t2 = deserialize(header, frames) assert (x == t2.numpy()).all() + @pytest.mark.parametrize("requires_grad", [True, False]) def test_grad(requires_grad): x = np.arange(10) @@ -23,13 +24,14 @@ def test_grad(requires_grad): t2 = deserialize(*serialize(t)) - assert (t2.requires_grad is requires_grad) - assert (t.requires_grad is requires_grad) + assert t2.requires_grad is requires_grad + assert t.requires_grad is requires_grad assert np.allclose(t2.detach().numpy(), x) if requires_grad: assert np.allclose(t2.grad.numpy(), 1) + def test_resnet(): torchvision = pytest.importorskip("torchvision") model = torchvision.models.resnet.resnet18()