In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np
import PIL

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [3]:
model1 = Net()
model2 = Net()
model3 = Net()
model4 = Net()

In [5]:
for i in model1.parameters():
    print(i.shape)

torch.Size([20, 1, 5, 5])
torch.Size([20])
torch.Size([50, 20, 5, 5])
torch.Size([50])
torch.Size([500, 800])
torch.Size([500])
torch.Size([10, 500])
torch.Size([10])


In [77]:
def weights_to_array(model):
    '''
        input: pytorch model
        output: array of tensors of model weights
    '''
    model_weights = []
    for param in model.parameters():
        model_weights.append(param) # check without .data
    return model_weights

In [158]:
def fed_avg_aggregator(model_data):
    '''
        input: array of tuples containing model 
        and the number of samples from each respective node
        output: fed_avg aggregated model
    '''
    total_no_samples = 0
    
    # creates an array of weights and sample counts
    # shape -> no_models*no_layers*dim_of_layer
    node_weights = []
    node_samples = []
    for model,no_samples in model_data:
        node_weights.append(weights_to_array(model))
        node_samples.append(no_samples)
    # calculates the total number of samples
        total_no_samples += no_samples
    
    aggregated_weights = []
    for layer_idx in range(len(node_weights[0])):
        temp = torch.zeros(node_weights[0][layer_idx].shape)
        for node_idx in range(len(node_weights)):
            temp+= (node_samples[node_idx]/total_no_samples)*node_weights[node_idx][layer_idx]
        aggregated_weights.append(temp)
    agg_model = Net()
    for idx, param in enumerate(agg_model.parameters()):
        param.data = aggregated_weights[idx]
    return agg_model

In [159]:
x = fed_avg_aggregator([(model1,60),(model2,40)])

tensor([[[[ 0.1513, -0.1399,  0.0211,  0.0361, -0.0181],
          [-0.0369,  0.0949,  0.1740,  0.1138, -0.1712],
          [-0.0825,  0.1080,  0.0870,  0.0283,  0.0249],
          [ 0.1040, -0.1336,  0.0483,  0.0348, -0.0379],
          [-0.0567,  0.0311, -0.1216,  0.0106,  0.0179]]],


        [[[ 0.0662,  0.0096,  0.0793, -0.0322, -0.0290],
          [-0.1006, -0.0307,  0.0354, -0.1119, -0.0379],
          [-0.0556,  0.1235, -0.0399,  0.1322, -0.1662],
          [ 0.1450, -0.1074, -0.0138,  0.0186,  0.0522],
          [-0.0330, -0.0985,  0.0467, -0.0484, -0.0887]]],


        [[[ 0.1198, -0.0691,  0.0575, -0.0543, -0.1341],
          [ 0.0373, -0.0113,  0.0674,  0.0892,  0.0156],
          [ 0.0556,  0.0005,  0.1078, -0.0512,  0.0843],
          [-0.1481, -0.0588, -0.0367,  0.0569,  0.0558],
          [-0.0490,  0.0772, -0.1342, -0.0407,  0.0893]]],


        [[[-0.0718,  0.0851,  0.0601, -0.1549, -0.1108],
          [-0.0817,  0.0584,  0.0724,  0.0734, -0.0604],
          [ 0.0794,

In [160]:
weights_to_array(model1)[0][0][0][0][0]*0.6+weights_to_array(model2)[0][0][0][0][0]*0.4

tensor(0.1513, grad_fn=<AddBackward0>)

In [161]:
x

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [162]:
for i in x.parameters():
    print(i)

Parameter containing:
tensor([[[[ 0.1513, -0.1399,  0.0211,  0.0361, -0.0181],
          [-0.0369,  0.0949,  0.1740,  0.1138, -0.1712],
          [-0.0825,  0.1080,  0.0870,  0.0283,  0.0249],
          [ 0.1040, -0.1336,  0.0483,  0.0348, -0.0379],
          [-0.0567,  0.0311, -0.1216,  0.0106,  0.0179]]],


        [[[ 0.0662,  0.0096,  0.0793, -0.0322, -0.0290],
          [-0.1006, -0.0307,  0.0354, -0.1119, -0.0379],
          [-0.0556,  0.1235, -0.0399,  0.1322, -0.1662],
          [ 0.1450, -0.1074, -0.0138,  0.0186,  0.0522],
          [-0.0330, -0.0985,  0.0467, -0.0484, -0.0887]]],


        [[[ 0.1198, -0.0691,  0.0575, -0.0543, -0.1341],
          [ 0.0373, -0.0113,  0.0674,  0.0892,  0.0156],
          [ 0.0556,  0.0005,  0.1078, -0.0512,  0.0843],
          [-0.1481, -0.0588, -0.0367,  0.0569,  0.0558],
          [-0.0490,  0.0772, -0.1342, -0.0407,  0.0893]]],


        [[[-0.0718,  0.0851,  0.0601, -0.1549, -0.1108],
          [-0.0817,  0.0584,  0.0724,  0.0734, -0.0604