In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np
import PIL

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [3]:
def weights_to_array(model):
    '''
        input: pytorch model
        output: array of tensors of model weights
    '''
    model_weights = []
    for param in model.parameters():
        model_weights.append(param) # check without .data
    return model_weights

In [4]:
def g(z, node_weights, node_samples, total_no_samples):
    '''
        Optimizer loss function for geometric median
        Refer equation 3 in Krishna Pillutla et al., Robust Aggregation for Federated Learning
        
        input:  z - aggregator weights to minimize
                node_weights - array of model weights from weights_to_array function
                node_samples - array of sample counts from each node
                total_no_samples - sum(node_samples)
        output: weighted summation of euclidean norm with respect to the aggregator weights        
    '''
    summation = 0.0
    for layer_idx in range(len(node_weights[0])):
        temp = torch.zeros(node_weights[0][layer_idx].shape)
        for node_idx in range(len(node_weights)):
            euclidean_norm = (z[layer_idx] - node_weights[node_idx][layer_idx])**2
            weight_alpha = node_samples[node_idx]/total_no_samples
            temp = temp + (weight_alpha * euclidean_norm)
        summation = sum(temp)
    return summation

# g(weights_to_array(agg_model), [weights_to_array(model1), weights_to_array(model2)], [5, 5], 10)

In [16]:
model1 = Net()
model2 = Net()
agg_model = Net()

In [7]:
def optimizeGM(agg_model, args):
    optimizer = optim.Adam(agg_model.parameters(), lr=0.001)
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = g(weights_to_array(agg_model), [weights_to_array(model1), weights_to_array(model1)], [5, 5], 10)
        print(loss)
        loss.backward()
        optimizer.step()
    return agg_model

In [19]:
for param in agg_model.parameters():
    print(param.data)
    param.data.requires_grad = True
    print(param.data.requires_grad)
    break

tensor([[[[-1.2504e-01,  7.7111e-02,  1.3537e-01, -1.1657e-01,  1.7329e-01],
          [-5.1712e-02,  2.1248e-02,  1.7058e-01,  4.8372e-02,  1.7051e-01],
          [ 1.2301e-01,  4.8677e-02,  1.6425e-01,  1.9989e-01, -5.1019e-02],
          [ 1.6017e-01, -1.6670e-01, -3.3064e-02, -1.9235e-01, -6.3099e-02],
          [ 4.9129e-02,  6.2872e-02, -9.2443e-02, -1.2771e-01,  1.0783e-01]]],


        [[[ 1.2969e-01,  1.9795e-01, -1.5388e-01,  8.6900e-02,  1.8385e-01],
          [-1.6360e-01, -7.3969e-04,  1.2249e-01, -8.6362e-02,  6.1612e-02],
          [-7.8150e-02, -1.9278e-01, -8.0221e-02,  1.8125e-01,  8.8826e-03],
          [ 7.6412e-02, -1.5301e-01,  1.7589e-01,  6.7143e-02,  8.6784e-02],
          [ 1.9356e-01, -3.1230e-02,  1.4794e-01, -5.5304e-02,  8.1687e-02]]],


        [[[ 1.7238e-01,  7.2237e-02, -9.4571e-02, -4.5212e-02, -1.3148e-01],
          [ 1.7791e-01, -7.2837e-02, -7.9420e-02, -1.1281e-01, -1.4283e-01],
          [-1.5918e-01,  1.1719e-01, -1.6291e-01, -1.5746e-01, -8.22

In [21]:
iterations = 100
weights = []
for param in agg_model.parameters():
    weights.append(param)
optimizer = optim.Adam(weights, lr=0.001)
a = []
for i in agg_model.parameters():
    a.append(i.clone())
for _ in range(iterations):
    optimizer.zero_grad()
    loss = g(weights_to_array(agg_model), [weights_to_array(model1), weights_to_array(model1)], [5, 5], 10)
    print(loss)
    loss.backward()
    optimizer.step()
b = []
for i in agg_model.parameters():
    b.append(i.clone())
for i in range(len(a)):
    print(torch.equal(a[i].data, b[i].data))


tensor(2.8004e-05, grad_fn=<AddBackward0>)
tensor(2.6082e-05, grad_fn=<AddBackward0>)
tensor(1.1383e-05, grad_fn=<AddBackward0>)
tensor(7.0989e-06, grad_fn=<AddBackward0>)
tensor(6.3351e-06, grad_fn=<AddBackward0>)
tensor(3.3580e-06, grad_fn=<AddBackward0>)
tensor(7.7103e-07, grad_fn=<AddBackward0>)
tensor(9.2439e-07, grad_fn=<AddBackward0>)
tensor(3.3211e-06, grad_fn=<AddBackward0>)
tensor(5.4769e-06, grad_fn=<AddBackward0>)
tensor(5.9227e-06, grad_fn=<AddBackward0>)
tensor(5.1705e-06, grad_fn=<AddBackward0>)
tensor(4.3712e-06, grad_fn=<AddBackward0>)
tensor(4.0471e-06, grad_fn=<AddBackward0>)
tensor(3.8088e-06, grad_fn=<AddBackward0>)
tensor(3.0791e-06, grad_fn=<AddBackward0>)
tensor(1.8714e-06, grad_fn=<AddBackward0>)
tensor(7.4469e-07, grad_fn=<AddBackward0>)
tensor(2.1300e-07, grad_fn=<AddBackward0>)
tensor(3.0815e-07, grad_fn=<AddBackward0>)
tensor(6.5271e-07, grad_fn=<AddBackward0>)
tensor(8.8309e-07, grad_fn=<AddBackward0>)
tensor(9.4358e-07, grad_fn=<AddBackward0>)
tensor(9.98

In [None]:
def Geometric_Median(model_data):
    '''
        input: array of tuples containing model 
        and the number of samples from each respective node
        output: geometric median aggregated model
    '''
    total_no_samples = 0
    
    # creates an array of weights and sample counts
    # shape -> no_models*no_layers*dim_of_layer
    node_weights = []
    node_samples = []
    for model,no_samples in model_data:
        node_weights.append(weights_to_array(model))
        node_samples.append(no_samples)
    # calculates the total number of samples
        total_no_samples += no_samples
    
    aggregated_weights = []
    for layer_idx in range(len(node_weights[0])):
        temp = torch.zeros(node_weights[0][layer_idx].shape)
        for node_idx in range(len(node_weights)):
            temp+= (node_samples[node_idx]/total_no_samples)*node_weights[node_idx][layer_idx]
        aggregated_weights.append(temp)
    agg_model = Net()
    for idx, param in enumerate(agg_model.parameters()):
        param.data = aggregated_weights[idx]
    return agg_model

In [41]:
model = Net()

In [43]:
params = list(model.parameters())

In [50]:
for p in params:
    print(p.shape)

torch.Size([20, 1, 5, 5])
torch.Size([20])
torch.Size([50, 20, 5, 5])
torch.Size([50])
torch.Size([500, 800])
torch.Size([500])
torch.Size([10, 500])
torch.Size([10])


In [74]:
sum(((params[0] - params[0]/2)**2).flatten())+0.0

tensor(1.6758, grad_fn=<AddBackward0>)

In [63]:
(1.9119e-01/2)**2

0.009138404025

In [75]:
l = [1,2,3,4,5,6,7]

In [90]:
t = torch.Tensor([[1,2,3,10,11],[1,2,50,90,100]])

In [92]:
torch.median(t,dim=-1)

torch.return_types.median(
values=tensor([ 3., 50.]),
indices=tensor([2, 2]))