In [424]:
import torch
import syft as sy  # import the Pysyft library
hook = sy.TorchHook(torch)  # hook PyTorch to add extra functionalities like Federated and Encrypted Learning

class Worker:
    """Bundle a PySyft VirtualWorker with the data and targets it holds.

    Args:
        _id: identifier given to the new ``sy.VirtualWorker``.
        data: tensor sent to the new worker; ``self.data`` keeps the pointer.
        target: tensor sent to the new worker; ``self.target`` keeps the pointer.
    """

    def __init__(self, _id, data, target):
        self.worker = sy.VirtualWorker(hook, id=_id)
        # Send to the worker object itself rather than to its id string, so the
        # pointers target exactly `self.worker` instead of whatever worker that
        # id resolves to in the hook's registry (e.g. after re-running a cell).
        self.data = data.send(self.worker)
        self.target = target.send(self.worker)



In [425]:
# Alice's shard. Inputs and labels are plain data: only the models' parameters are
# optimized, so they do not need requires_grad (which would make every backward()
# also compute and accumulate unused gradients on the data). The `1.` literals keep
# these tensors float.
alice_data = torch.tensor([[0, 0], [1, 1.]]).tag('#boundary', '#toy')
alice_target = torch.tensor([[0], [1.]]).tag('#boundary_target', '#toy_target')

alice = Worker('alice', alice_data, alice_target)

In [426]:
# Bob's shard. These are integer tensors (no float literals), so the training loop
# casts them with .float() before use.
bob_data = torch.tensor([[0, 1],
                         [1, 0]]).tag('#middle', '#toy')
bob_target = torch.tensor([[0],
                           [1]]).tag('#middle_target', '#toy_target')

bob = Worker(_id='bob', data=bob_data, target=bob_target)

In [427]:
# The node's shard. It is not used by the training loop below: it evaluates the
# averaged model, and node.worker is also where Alice's and Bob's models are averaged.
node_data = torch.tensor([[0, 0],
                          [1, 0]]).tag('#alt', '#toy')
node_target = torch.tensor([[0],
                            [1]]).tag('#alt_target', '#toy_target')

node = Worker(_id='node', data=node_data, target=node_target)

In [439]:
# Look up the objects on Alice's worker tagged "#toy" and display what was found.
check_alice_for_toy = alice.worker.search(["#toy"])
check_alice_for_toy

In [440]:
# Look up the objects on Bob's worker tagged "#toy" and display what was found.
check_bob_for_toy = bob.worker.search(["#toy"])
check_bob_for_toy

In [430]:
import copy
from torch import nn, optim

In [441]:
# Initialize a toy model: a single linear layer mapping 2 input features to 1 output.
# Seed first so the initial weights, and therefore the losses printed later,
# are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = nn.Linear(2, 1)

In [442]:
# Evaluate the current global `model` on the node's data: send a copy to
# node.worker, run the forward pass there, and fetch back only the summed squared error.
nodes_model = model.copy().send(node.worker)
# Optimizer over the model that was just sent (must reference `nodes_model`; this
# cell only evaluates, so it is never stepped here).
nodes_opt = optim.SGD(params=nodes_model.parameters(), lr=0.1)
preds = nodes_model(node.data.float())
loss = ((preds - node.target.float()) ** 2).sum()
print(loss.get().data)

tensor(0.6060)


In [433]:
iterations = 10      # federated rounds: send copies, train locally, average
worker_iters = 5     # local optimizer steps per worker per round
learning_rate = 0.1


def train_step(model_ptr, opt, data, target):
    """Run one optimizer step on a remote model and return its loss fetched locally.

    Args:
        model_ptr: model already sent to the worker that holds ``data``/``target``.
        opt: optimizer over ``model_ptr.parameters()``.
        data: input pointer on that worker, already cast as needed.
        target: target pointer on that worker, already cast as needed.

    Returns:
        The summed squared error of this step, retrieved with ``.get().data``.
    """
    opt.zero_grad()
    pred = model_ptr(data)
    loss = ((pred - target) ** 2).sum()
    loss.backward()
    opt.step()
    return loss.get().data


for a_iter in range(iterations):

    # Each round starts from a fresh copy of the current global model on each worker.
    bobs_model = model.copy().send(bob.worker)
    alices_model = model.copy().send(alice.worker)

    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=learning_rate)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=learning_rate)

    # Bound up front so the report below does not raise NameError when worker_iters == 0.
    bobs_loss = alices_loss = None
    for wi in range(worker_iters):
        # Bob's tensors are integer, so cast them; Alice's are created as float.
        bobs_loss = train_step(bobs_model, bobs_opt, bob.data.float(), bob.target.float())
        alices_loss = train_step(alices_model, alices_opt, alice.data, alice.target)

    # Move both locally trained models to node.worker so they can be averaged there.
    alices_model.move(node.worker)
    bobs_model.move(node.worker)

    # Overwrite the global model's parameters with the element-wise mean of the two
    # workers' parameters, fetched back from node.worker.
    with torch.no_grad():
        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())

    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))

Bob:tensor(0.0193) Alice:tensor(0.0280)
Bob:tensor(0.0084) Alice:tensor(0.0136)
Bob:tensor(0.0037) Alice:tensor(0.0079)
Bob:tensor(0.0016) Alice:tensor(0.0046)
Bob:tensor(0.0007) Alice:tensor(0.0027)
Bob:tensor(0.0003) Alice:tensor(0.0016)
Bob:tensor(0.0001) Alice:tensor(0.0009)
Bob:tensor(6.1581e-05) Alice:tensor(0.0005)
Bob:tensor(2.7140e-05) Alice:tensor(0.0003)
Bob:tensor(1.1962e-05) Alice:tensor(0.0002)


In [434]:
# Re-evaluate the averaged global `model` on the node's data: send a copy to
# node.worker, run the forward pass there, and fetch back only the summed squared error.
nodes_model = model.copy().send(node.worker)
# Optimizer over the model that was just sent (must reference `nodes_model`; this
# cell only evaluates, so it is never stepped here).
nodes_opt = optim.SGD(params=nodes_model.parameters(), lr=0.1)
preds = nodes_model(node.data.float())
loss = ((preds - node.target.float()) ** 2).sum()
print(loss.get().data)

tensor(0.0002)
