In [117]:
%load_ext autoreload
%autoreload 2

import os, sys
import copy
import socket
from tqdm import tqdm
import torch
import pickle
from torch import optim
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import fl, nn, agg, data, poison, log, sim

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [118]:
class FedArgs():
    def __init__(self):
        self.num_clients = 50
        self.epochs = 10
        self.local_rounds = 1
        self.client_batch_size = 32
        self.test_batch_size = 128
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5
        self.cuda = False
        self.seed = 1
        self.tb = SummaryWriter('../../out/runs/federated/FLTrust', comment="Mnist Centralized Federated training")

fedargs = FedArgs()

In [119]:
use_cuda = fedargs.cuda and torch.cuda.is_available()
torch.manual_seed(fedargs.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

In [120]:
host = socket.gethostname()
clients = [host + "(" + str(client + 1) + ")" for client in range(fedargs.num_clients)]

In [121]:
#Initialize Global and Client models
global_model = nn.ModelMNIST()
client_models = {client: copy.deepcopy(global_model) for client in clients}

# Function for training
def train_model(model, train_loader, fedargs):
    model, loss = fl.client_update(model,
                                train_loader,
                                fedargs.learning_rate,
                                fedargs.weight_decay,
                                fedargs.local_rounds,
                                device)
    return model, loss

In [122]:
# Load MNIST Data to clients
train_data, test_data = data.load_dataset("mnist")

In [123]:
# For FLTrust
#############Skip this section for running other averaging
FLTrust = True
root_ratio = 0.01
train_data, root_data = torch.utils.data.random_split(train_data, [int(len(train_data) * (1-root_ratio)), 
                                                              int(len(train_data) * root_ratio)])
root_loader = torch.utils.data.DataLoader(root_data, batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)

#global_model, _ = train_model(global_model, root_loader, fedargs)
#client_models = {client: copy.deepcopy(global_model) for client in clients}
#############

In [124]:
clients_data = data.split_data(train_data, clients)

In [125]:
# Poison a client
################Skip this section for running without poison
for client in range(10):
    clients_data[clients[client]] = poison.label_flip(clients_data[clients[client]], 4, 9, poison_percent = -1)
    
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 6, 2, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 3, 8, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 1, 5, poison_percent = 1)

In [126]:
client_train_loaders, _ = data.load_client_data(clients_data, fedargs.client_batch_size, None, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=fedargs.test_batch_size, shuffle=True, **kwargs)

clients_info = {
        client: {"train_loader": client_train_loaders[client]}
        for client in clients
    }

In [127]:
def get_model_diff(base, model):
    params1 = model.state_dict().copy()
    params2 = base.state_dict().copy()
    with torch.no_grad():
        for name1 in params1:
            if name1 in params2:
                params1[name1] = params1[name1] - params2[name1]
    _model = copy.deepcopy(model)
    _model.load_state_dict(params1, strict=False)
    return _model

base_model = nn.ModelMNIST()
_base_model = copy.deepcopy(base_model)

_A1 = copy.deepcopy(base_model)
_A2 = copy.deepcopy(base_model)

_C3 = copy.deepcopy(base_model)
_C4 = copy.deepcopy(base_model)
_C5 = copy.deepcopy(base_model)

for i in range(3):
    _base_model, _ = train_model(_base_model, root_loader, fedargs)
    _A1, _ = train_model(_A1, clients_info[clients[0]]['train_loader'], fedargs)
    _A2, _ = train_model(_A2, clients_info[clients[1]]['train_loader'], fedargs)

    _C3, _ = train_model(_C3, clients_info[clients[20]]['train_loader'], fedargs)
    _C4, _ = train_model(_C4, clients_info[clients[21]]['train_loader'], fedargs)
    _C5, _ = train_model(_C4, clients_info[clients[22]]['train_loader'], fedargs)
    
    client_models = {0: _A1, 1: _A2, 2: _C3, 3: _C4, 4: _C5}
    
    #global_model = fl.federated_avg(client_models, _base_model, agg.Rule.FLTrust)
    
_base_model = get_model_diff(base_model, _base_model)
_A1 = get_model_diff(base_model, _A1)
_A2 = get_model_diff(base_model, _A2)
_C3 = get_model_diff(base_model, _C3)
_C4 = get_model_diff(base_model, _C4)
_C5 = get_model_diff(base_model, _C5)

In [130]:
print(fl.eval(global_model, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(_base_model, _A1))
print(fl.eval(_A1, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(_base_model, _A2))
print(fl.eval(_A2, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(_base_model, _C3))
print(fl.eval(_C3, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(_base_model, _C4))
print(fl.eval(_C4, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(_base_model, _C5))
print(fl.eval(_C5, test_loader, device, 4, 9)['accuracy'])

print("###############")

print(sim.grad_cosine_similarity(nn.ModelMNIST(), nn.ModelMNIST()))

10.66
###############
0.3191995
10.09
###############
0.3230263
10.09
###############
0.5237853
32.31
###############
0.5836278
30.080000000000002
###############
0.563691
39.910000000000004
###############
0.007604853


In [131]:
print(sim.grad_norm(_base_model), sim.grad_norm(_A1), sim.grad_norm(_A2), sim.grad_norm(_C3), sim.grad_norm(_C4), sim.grad_norm(_C5))

2.2304773 0.9644073 0.9594911 3.625153 3.5358863 4.0946817
