In [1]:
import torch as th
import syft as sy  # <-- NEW: import the Pysyft library
hook = sy.TorchHook(th)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice

In [2]:
x = th.tensor([1,2,3,4,5])

In [3]:
clients = [bob, alice]
x_ptr = x.send(*clients)

In [4]:
x_ptr.child.child[bob.id]

[PointerTensor | me:1711462308 -> bob:53085309076]

In [5]:
def get_avg_results(ptr):
    n_childs = len(ptr.child.child)
    return ptr.get(sum_results=True) / n_childs

In [6]:
get_avg_results(x_ptr)

tensor([1, 2, 3, 4, 5])

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [8]:
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = Net().to(device)

In [31]:
model.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [9]:
model2 = model.copy()
model3 = model.copy()

In [10]:
models = [model, model2, model3]

In [21]:
final_model = models[0].copy()
final_state_dict = {}

parameters = [model.state_dict() for model in models]
with th.no_grad():
    for parameter_name in parameters[0].keys():
        final_state_dict[parameter_name] = th.mean(th.stack([model_parameters[parameter_name] for model_parameters in parameters]), dim=0)
    final_model.load_state_dict(final_state_dict)

In [30]:
th.abs(model.state_dict()['conv1.weight'] - final_model.state_dict()['conv1.weight']) < 1e-5

tensor([[[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, T

In [25]:
final_model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 1.3949e-01,  1.1792e-01, -1.4510e-01, -1.2963e-01, -1.5563e-01],
                        [ 9.7790e-03, -9.2998e-02,  1.5171e-01,  1.5338e-01, -1.2391e-01],
                        [ 1.0206e-01, -6.6261e-03,  1.3547e-01, -1.9324e-01, -9.0682e-02],
                        [ 4.1984e-02,  9.0107e-02, -1.6810e-02,  7.7416e-02, -4.3510e-02],
                        [-6.9508e-02,  1.2414e-01, -1.2825e-01,  8.1208e-02, -3.3040e-03]]],
              
              
                      [[[-7.0155e-02,  1.7508e-01,  6.2299e-02,  5.4426e-03, -1.8033e-01],
                        [ 6.0088e-03, -1.9241e-01, -1.4510e-01,  2.9196e-02,  1.0425e-01],
                        [ 1.2502e-01,  2.9769e-02, -1.7134e-01,  9.9133e-02,  8.9915e-02],
                        [ 1.5579e-01,  1.5782e-01, -1.8115e-01,  1.0880e-01,  1.2208e-01],
                        [-1.9508e-01,  1.5711e-02, -1.3522e-01, -1.4402e-01,  1.5913e-01]]],
              
           

In [20]:
th.mean(th.stack([model_parameters[parameter_name] for model_parameters in parameters]), dim=0)

torch.Size([20, 1, 5, 5])

In [None]:
model.state_dict()['conv1.weight']

In [None]:
model.train()
model.send(bob, alice)

In [None]:
dir(model)