In [1]:
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import syft as sy

# Hook PyTorch with syft, then create three in-process virtual workers
# (created in the same order: bob, alice, charlie).
hook = sy.TorchHook(th)
bob, alice, charlie = (sy.VirtualWorker(hook, id=name) for name in ("bob", "alice", "charlie"))

# Also expose the full module names; later cells refer to `torch`.
torch, syft = th, sy

# The worker that stands for this notebook process. Client mode is turned off,
# presumably so this worker keeps the tensors it creates in `me._objects`
# (which a later cell inspects) — confirm against the syft version in use.
me = hook.local_worker
me.is_client_worker = False

Build a local computation graph

In [2]:
# Leaf tensor whose gradient we want to recover at the end.
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
# y = 2x, built locally so its AddBackward0 node links back to x.
y = x + x
# Register y explicitly on the local worker: automatic registration
# on the local worker has been observed to be unreliable.
me.register_obj(y)

In [3]:
# Inspect the local worker's object store: both x and y should be listed.
me._objects

{30324271312: tensor([1., 2., 3., 4., 5.], requires_grad=True),
 47299484475: tensor([ 2.,  4.,  6.,  8., 10.], grad_fn=<AddBackward0>)}

In [4]:
# Send y to alice; y_ptr is the local pointer to alice's copy.
y_ptr = y.send(alice)

sender me to alice
Setting Gradient on 49886699849 me
Setting Gradient on 49886699849 me


Remote computation graph + a remote call to backward()

In [5]:
# Operate through the pointer: the sum and the backward pass act on alice's copy of y.
z = y_ptr.sum()
z.backward()

# Check that the gradients have been computed remotely: yes!
remote_grad = alice._objects[y_ptr.id_at_location].grad
print(remote_grad)

tensor([1., 1., 1., 1., 1.])


The backward pass did not reach the local graph: x and y have no gradients yet :(

In [6]:
# No gradient has reached the local tensors yet: the backward pass above ran on alice only.
x.grad, y.grad

(None, None)

Let's manually trigger the callback that runs backward on the origin (local) graph

In [7]:
# Look up alice's actual tensor (not the pointer) behind y_ptr and show
# where it came from: the sending worker and the object id it had there.
remote_y = alice._objects[y_ptr.id_at_location]
origin_info = remote_y.origin
print(origin_info)

{'sender': 'me', 'origin_id': 47299484475}


In [8]:
# Ask alice's tensor to send its gradient back to its origin (y on `me`),
# so that backward continues through the local graph down to x.
remote_y.trigger_origin_backward()

<VirtualWorker id:me #objects:4>
origin_ptr [PointerTensor | alice:47299484475 -> me:47299484475]
Setting Gradient on 47299484475 me


Yeah! The local graph has been updated!

In [9]:
# y = x + x, so each element gets 2 * y.grad (all ones): expect 2s.
x.grad

tensor([2., 2., 2., 2., 2.])

In [10]:
# Gradient delivered from alice's copy of y: d(sum)/dy = 1 per element.
y.grad

tensor([1., 1., 1., 1., 1.])