In [1]:
%load_ext autoreload
%autoreload 2

In [2]:




import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from abc import ABC
from tqdm import tqdm
import torchvision
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random




config = {}
config["seed"] = 42
seed = config["seed"]
os.environ['PYTHONHASHSEED'] = str(seed)
# Torch RNG
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Python RNG
np.random.seed(seed)
random.seed(seed)




config["num_clients_per_cluster"] = 8
config["num_clusters"] = 4
config["num_clients"] = config["num_clients_per_cluster"]*config["num_clusters"]
config["dataset"] = "mnist"
DATASET_LIB = {"mnist" : torchvision.datasets.MNIST, "emnist": torchvision.datasets.EMNIST, "cifar10": torchvision.datasets.CIFAR10}
config["dataset_dir"] = "./experiments/dataset"
config["results_dir"] = "./experiments/results"
config["results_dir"] = os.path.join(config["results_dir"], config["dataset"], "seed_{}".format(seed))

os.makedirs(config["results_dir"], exist_ok=True)

def split(dataset_size, num_clients, shuffle):
    split_idx = []
    all_idx = np.arange(dataset_size)
    if shuffle:
        all_idx = np.random.permutation(all_idx)
    split_idx = np.array_split(all_idx, num_clients)
    return split_idx

def dataset_split(train_data, test_data, num_clients, shuffle):
    train_size = train_data[0].shape[0]
    train_split_idx = split(train_size, num_clients, shuffle)
    train_chunks = [(train_data[0][train_split_idx[client]], train_data[1][train_split_idx[client]]) for client in range(num_clients)]
    test_size = test_data[0].shape[0]
    test_split_idx = split(test_size, num_clients, shuffle)
    test_chunks = [(test_data[0][test_split_idx[client]], test_data[1][test_split_idx[client]]) for client in range(num_clients)]
    return train_chunks, test_chunks

def make_client_datasets(config, num_clusters = 10):
    train_chunks_total = []
    test_chunks_total = []
    train_dataset = DATASET_LIB[config["dataset"]](root = config['dataset_dir'], download = True, train=True)
    test_dataset = DATASET_LIB[config["dataset"]](root = config['dataset_dir'], download = True, train=False)

    train_data = (train_dataset.data, train_dataset.targets)
    test_data = (test_dataset.data, test_dataset.targets)
    for i in range(config["num_clusters"]):
        train_chunks, test_chunks = dataset_split(train_data, test_data, config["num_clients_per_cluster"], shuffle=True)
        train_chunks_total += train_chunks
        test_chunks_total += test_chunks
    return train_chunks_total, test_chunks_total


class ClientDataset(Dataset):
    def __init__(self, data,transforms = None):
        super(ClientDataset,self).__init__()
        self.data = data[0]
        self.labels = data[1]
        self.transforms = transforms

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        idx_data = self.data[idx]
        if self.transforms is not None:
            transformed_data =self.transforms(idx_data)
        else:
            transformed_data = idx_data
        idx_labels = self.labels[idx]
        return (transformed_data.unsqueeze(0).float(), idx_labels)



In [3]:
class Client():
    def __init__(self, train_data, test_data, client_id,  train_transforms, test_transforms, train_batch_size, test_batch_size, save_dir):
        self.trainset = ClientDataset(train_data, train_transforms)
        self.testset = ClientDataset(test_data, test_transforms)
        self.trainloader = DataLoader(self.trainset, batch_size = train_batch_size, shuffle=True, num_workers=4)
        self.testloader = DataLoader(self.testset, batch_size = test_batch_size, shuffle=False, num_workers=4)
        self.train_iterator = iter(self.trainloader)
        self.test_iterator = iter(self.testloader)
        self.client_id = client_id
        self.save_dir = os.path.join(save_dir, "init", "client_{}".format(client_id))

    def sample_batch(self, train=True):
        iterator = self.train_iterator if train else self.test_iterator
        try:
            (data, labels) = next(iterator)
        except StopIteration:
            if train:
                self.train_iterator = iter(self.trainloader)
                iterator = self.train_iterator
            else:
                self.test_iterator = iter(self.testloader)
                iterator = self.test_iterator
            (data, labels) = next(iterator)
        return (data, labels)


In [4]:
train_chunks, test_chunks = make_client_datasets(config)


In [5]:
import torchvision.transforms.functional as TF


class RotationTransform:
    """Rotate by one of the given angles."""

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x):
        return TF.rotate(x, self.angle)


config["train_batch"] = 512
config["test_batch"] = 512
client_loaders = []
for i in range(config["num_clusters"]):
    for j in range(config["num_clients_per_cluster"]):
        idx = i * config["num_clients_per_cluster"] + j
        x_train = train_chunks[idx][0]
        x_test = test_chunks[idx][0]
        if i >0:
            x_train = torch.rot90(x_train, i, [1, 2])
            x_test = torch.rot90(x_test, i, [1, 2])
        client_loaders.append(
            Client(
                (x_train, train_chunks[idx][1]),
                (x_test, test_chunks[idx][1]),
                idx,
                train_transforms=None,
                test_transforms=None,
                train_batch_size=config["train_batch"],
                test_batch_size=config["test_batch"],
                save_dir=config["results_dir"],
            )
        )


In [6]:
class SimpleLinear(torch.nn.Module):

    def __init__(self, h1=2048):
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, h1)
        self.fc2 = torch.nn.Linear(h1, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        # x = F.sigmoid(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
# def train_fedavg_model(train_chunks, test_chunks, config):
# model = SimpleLinear()
# model.to(torch.device("cuda:0"))
# train_data = []
# train_labels = [chunk[1] for chunk in train_chunks]
# test_labels = [chunk[1] for chunk in test_chunks]
# test_data = []
# for i in range(config["num_clusters"]):
#     for j in range(config['num_clients_per_cluster']):
#         idx = i*config["num_clients_per_cluster"] + j
#         train_data.append()
# train_set = ClientDataset(data)

In [7]:
def calc_acc(model, device, client_data, train):
    loader = client_data.trainloader if train else client_data.testloader
    model.eval()
    acc = 0
    with torch.no_grad():
        for (X,Y) in loader:
            X = X.to(device)
            pred = model(X).argmax(axis=1).detach().cpu()
            acc += (Y == pred).float().mean()
    acc = acc/len(loader)
    acc *= 100.0
    return acc


class BaseTrainer(ABC):
    def __init__(self,config, save_dir):
        super(BaseTrainer, self).__init__()
        self.model = MODEL_LIST[config["model"]](**config["model_params"])
        self.save_dir = save_dir
        self.device = config["device"]
        self.loss_func = LOSSES[config["loss_func"]]
        self.config = config
        os.makedirs(self.save_dir, exist_ok=True)

    def train(self):
        raise NotImplementedError
    
    def test(self):
        raise NotImplementedError

    def load_model_weights(self):
        model_path  = os.path.join(self.save_dir, "model.pth")
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path))
        else:
            print("No model present at path : {}".format())

    def save_model_weights(self):
        model_path  = os.path.join(self.save_dir, "model.pth")
        torch.save(self.model.state_dict(), model_path)
    def save_metrics(self, train_loss, test_acc):
        torch.save({"train_loss": train_loss,  "test_acc" : test_acc}, os.path.join(self.save_dir,"metrics.pkl"))

class ClientTrainer(BaseTrainer):
    def __init__(self,  config, save_dir,client_id):
        super(ClientTrainer, self).__init__(config, save_dir)
        self.client_id = client_id
    
    def train(self, client_data):
        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        for iteration in tqdm(range(self.config["iterations"])):
            self.model.zero_grad()
            (X,Y) = client_data.sample_batch(train=True)
            X = X.to(self.device)
            Y = Y.to(self.device)
            out = self.model(X)
            loss = self.loss_func(out, Y)
            loss.backward()
            optimizer.step()
            train_loss = loss.detach().cpu().numpy().item()
            train_loss_list.append(train_loss)
            test_acc = calc_acc(self.model, self.device, client_data, train=False)
            test_acc_list.append(test_acc)
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list)
            if iteration % self.config["print_freq"] == 0 or iteration == self.config["iterations"] - 1:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))
                
        self.model.eval()
        self.model.cpu()


    def test(self, client_data):
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        acc =  calc_acc(self.model, client_data)
        self.model.cpu()
        return acc


  
MODEL_LIST = {"lin" : SimpleLinear}
OPTIMIZER_LIST = {"sgd": optim.SGD, "adam": optim.Adam}
LOSSES = {"cross_entropy": nn.CrossEntropyLoss()}
# config["save_dir"] = os.path.join("./results")
config["iterations"] = 50
config["optimizer_params"] = {"lr":0.001}
config["save_freq"] = 2
config["print_freq"]  = 10
config["model"] = "lin"
config["optimizer"] = "adam"
config["loss_func"] = "cross_entropy"
#config["model_params"] = {"num_channels": 1 , "num_classes"  : 62}
config["model_params"] = {}
config["device"] = torch.device("cuda:0")
import pickle
client_trainers = [ClientTrainer(config,os.path.join(config["results_dir"], "init", "node_{}".format(i)), i) for i in range(config["num_clients"])]


for i in tqdm(range(config["num_clients"])):
    client_trainers[i].train(client_loaders[i])
    


  0%|          | 0/32 [00:00<?, ?it/s]

Iteration : 0 
 , Train Loss : 23.324874877929688 
, Test Acc : 33.73548126220703 





Iteration : 10 
 , Train Loss : 8.788775444030762 
, Test Acc : 76.26693725585938 





Iteration : 20 
 , Train Loss : 3.646134614944458 
, Test Acc : 87.07308197021484 





Iteration : 30 
 , Train Loss : 1.7003434896469116 
, Test Acc : 91.43390655517578 





Iteration : 40 
 , Train Loss : 1.07925283908844 
, Test Acc : 92.1500473022461 



100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
  3%|▎         | 1/32 [00:20<10:43, 20.75s/it]

Iteration : 49 
 , Train Loss : 0.41405028104782104 
, Test Acc : 92.8183822631836 





Iteration : 0 
 , Train Loss : 25.738914489746094 
, Test Acc : 38.08824157714844 





Iteration : 10 
 , Train Loss : 15.305306434631348 
, Test Acc : 76.67485046386719 





Iteration : 20 
 , Train Loss : 1.612464427947998 
, Test Acc : 90.639404296875 





Iteration : 30 
 , Train Loss : 0.8251586556434631 
, Test Acc : 91.91498565673828 





Iteration : 40 
 , Train Loss : 0.8059536814689636 
, Test Acc : 92.66166687011719 



100%|██████████| 50/50 [00:15<00:00,  3.20it/s]
  6%|▋         | 2/32 [00:36<08:51, 17.73s/it]

Iteration : 49 
 , Train Loss : 0.26487964391708374 
, Test Acc : 93.28217315673828 





Iteration : 0 
 , Train Loss : 26.702606201171875 
, Test Acc : 26.684066772460938 





Iteration : 10 
 , Train Loss : 11.642060279846191 
, Test Acc : 77.07872009277344 





Iteration : 20 
 , Train Loss : 3.0946877002716064 
, Test Acc : 86.4087905883789 





Iteration : 30 
 , Train Loss : 1.111997127532959 
, Test Acc : 90.42681121826172 





Iteration : 40 
 , Train Loss : 0.8151707053184509 
, Test Acc : 92.17540740966797 



100%|██████████| 50/50 [00:15<00:00,  3.18it/s]
  9%|▉         | 3/32 [00:52<08:07, 16.82s/it]

Iteration : 49 
 , Train Loss : 0.2926771938800812 
, Test Acc : 93.16924285888672 





Iteration : 0 
 , Train Loss : 30.818038940429688 
, Test Acc : 41.38608169555664 





Iteration : 10 
 , Train Loss : 6.650596618652344 
, Test Acc : 84.64290618896484 





Iteration : 20 
 , Train Loss : 3.266239881515503 
, Test Acc : 90.13182830810547 





Iteration : 30 
 , Train Loss : 1.2674496173858643 
, Test Acc : 92.2324447631836 





Iteration : 40 
 , Train Loss : 0.6763841509819031 
, Test Acc : 94.66664123535156 



100%|██████████| 50/50 [00:15<00:00,  3.17it/s]
 12%|█▎        | 4/32 [01:07<07:39, 16.41s/it]

Iteration : 49 
 , Train Loss : 0.3778547942638397 
, Test Acc : 94.56697082519531 





Iteration : 0 
 , Train Loss : 35.26097106933594 
, Test Acc : 29.457040786743164 





Iteration : 10 
 , Train Loss : 7.968413352966309 
, Test Acc : 77.74705505371094 





Iteration : 20 
 , Train Loss : 4.6050190925598145 
, Test Acc : 85.2369155883789 





Iteration : 30 
 , Train Loss : 2.468611717224121 
, Test Acc : 89.99757385253906 





Iteration : 40 
 , Train Loss : 0.7983015775680542 
, Test Acc : 91.91094970703125 



100%|██████████| 50/50 [00:15<00:00,  3.17it/s]
 16%|█▌        | 5/32 [01:23<07:17, 16.19s/it]

Iteration : 49 
 , Train Loss : 0.8657839298248291 
, Test Acc : 93.23031616210938 





Iteration : 0 
 , Train Loss : 30.516916275024414 
, Test Acc : 23.25601577758789 





Iteration : 10 
 , Train Loss : 9.333032608032227 
, Test Acc : 68.90325927734375 





Iteration : 20 
 , Train Loss : 7.008359909057617 
, Test Acc : 84.33870697021484 





Iteration : 30 
 , Train Loss : 1.8374418020248413 
, Test Acc : 89.67205810546875 





Iteration : 40 
 , Train Loss : 1.7665431499481201 
, Test Acc : 91.33827209472656 



100%|██████████| 50/50 [00:15<00:00,  3.18it/s]
 19%|█▉        | 6/32 [01:39<06:56, 16.02s/it]

Iteration : 49 
 , Train Loss : 0.443351149559021 
, Test Acc : 91.8418197631836 





Iteration : 0 
 , Train Loss : 27.242996215820312 
, Test Acc : 38.77788543701172 





Iteration : 10 
 , Train Loss : 12.042145729064941 
, Test Acc : 82.96344757080078 





Iteration : 20 
 , Train Loss : 2.024482250213623 
, Test Acc : 87.41588592529297 





Iteration : 30 
 , Train Loss : 1.6573280096054077 
, Test Acc : 89.35978698730469 





Iteration : 40 
 , Train Loss : 0.5033729672431946 
, Test Acc : 90.97413635253906 



100%|██████████| 50/50 [00:15<00:00,  3.15it/s]
 22%|██▏       | 7/32 [01:55<06:39, 15.97s/it]

Iteration : 49 
 , Train Loss : 0.5590364933013916 
, Test Acc : 91.59868621826172 





Iteration : 0 
 , Train Loss : 24.693584442138672 
, Test Acc : 42.0146598815918 





Iteration : 10 
 , Train Loss : 10.213240623474121 
, Test Acc : 75.55597686767578 





Iteration : 20 
 , Train Loss : 4.133439064025879 
, Test Acc : 82.58607482910156 





Iteration : 30 
 , Train Loss : 1.26424241065979 
, Test Acc : 89.11665344238281 





Iteration : 40 
 , Train Loss : 1.4837745428085327 
, Test Acc : 90.87850189208984 



100%|██████████| 50/50 [00:16<00:00,  3.05it/s]
 25%|██▌       | 8/32 [02:11<06:26, 16.12s/it]

Iteration : 49 
 , Train Loss : 0.45162564516067505 
, Test Acc : 91.03924560546875 





Iteration : 0 
 , Train Loss : 23.724178314208984 
, Test Acc : 19.420055389404297 





Iteration : 10 
 , Train Loss : 8.921066284179688 
, Test Acc : 80.95846557617188 





Iteration : 20 
 , Train Loss : 5.8202643394470215 
, Test Acc : 85.59700012207031 





Iteration : 30 
 , Train Loss : 1.9501302242279053 
, Test Acc : 90.94764709472656 





Iteration : 40 
 , Train Loss : 1.061132788658142 
, Test Acc : 91.1734848022461 



100%|██████████| 50/50 [00:16<00:00,  3.09it/s]
 28%|██▊       | 9/32 [02:27<06:11, 16.14s/it]

Iteration : 49 
 , Train Loss : 0.697250485420227 
, Test Acc : 93.0177230834961 





Iteration : 0 
 , Train Loss : 30.511754989624023 
, Test Acc : 37.50634002685547 





Iteration : 10 
 , Train Loss : 10.951386451721191 
, Test Acc : 77.71247863769531 





Iteration : 20 
 , Train Loss : 2.114656925201416 
, Test Acc : 85.3498306274414 





Iteration : 30 
 , Train Loss : 1.0520790815353394 
, Test Acc : 89.78498840332031 





Iteration : 40 
 , Train Loss : 0.732140302658081 
, Test Acc : 92.58849334716797 



100%|██████████| 50/50 [00:16<00:00,  3.12it/s]
 31%|███▏      | 10/32 [02:43<05:54, 16.11s/it]

Iteration : 49 
 , Train Loss : 0.20676450431346893 
, Test Acc : 92.26297760009766 





Iteration : 0 
 , Train Loss : 30.972923278808594 
, Test Acc : 38.21845245361328 





Iteration : 10 
 , Train Loss : 10.42754077911377 
, Test Acc : 66.36419677734375 





Iteration : 20 
 , Train Loss : 3.2890427112579346 
, Test Acc : 82.97669982910156 





Iteration : 30 
 , Train Loss : 0.8900904059410095 
, Test Acc : 88.7565689086914 





Iteration : 40 
 , Train Loss : 0.8628529906272888 
, Test Acc : 90.35363006591797 



100%|██████████| 50/50 [00:16<00:00,  3.07it/s]
 34%|███▍      | 11/32 [03:00<05:39, 16.16s/it]

Iteration : 49 
 , Train Loss : 0.381736695766449 
, Test Acc : 91.16944885253906 





Iteration : 0 
 , Train Loss : 33.97298812866211 
, Test Acc : 37.69358444213867 





Iteration : 10 
 , Train Loss : 11.309563636779785 
, Test Acc : 74.97407531738281 





Iteration : 20 
 , Train Loss : 3.148554563522339 
, Test Acc : 89.77577209472656 





Iteration : 30 
 , Train Loss : 1.2280439138412476 
, Test Acc : 92.28429412841797 





Iteration : 40 
 , Train Loss : 0.9488826990127563 
, Test Acc : 93.52127075195312 



100%|██████████| 50/50 [00:16<00:00,  3.06it/s]
 38%|███▊      | 12/32 [03:16<05:24, 16.22s/it]

Iteration : 49 
 , Train Loss : 0.7185647487640381 
, Test Acc : 93.65148162841797 





Iteration : 0 
 , Train Loss : 22.169206619262695 
, Test Acc : 30.472206115722656 





Iteration : 10 
 , Train Loss : 9.616070747375488 
, Test Acc : 81.0276107788086 





Iteration : 20 
 , Train Loss : 3.338977813720703 
, Test Acc : 89.32926177978516 





Iteration : 30 
 , Train Loss : 2.463779926300049 
, Test Acc : 90.61808776855469 





Iteration : 40 
 , Train Loss : 0.850493311882019 
, Test Acc : 91.920166015625 



100%|██████████| 50/50 [00:16<00:00,  3.01it/s]
 41%|████      | 13/32 [03:33<05:10, 16.34s/it]

Iteration : 49 
 , Train Loss : 0.5699601173400879 
, Test Acc : 91.80724334716797 





Iteration : 0 
 , Train Loss : 25.733901977539062 
, Test Acc : 23.17765998840332 





Iteration : 10 
 , Train Loss : 14.198567390441895 
, Test Acc : 80.68883514404297 





Iteration : 20 
 , Train Loss : 4.543600082397461 
, Test Acc : 83.1455078125 





Iteration : 30 
 , Train Loss : 1.3814752101898193 
, Test Acc : 89.88465881347656 





Iteration : 40 
 , Train Loss : 1.440469741821289 
, Test Acc : 91.66378784179688 



100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
 44%|████▍     | 14/32 [03:49<04:54, 16.39s/it]

Iteration : 49 
 , Train Loss : 0.5700271129608154 
, Test Acc : 92.54470825195312 





Iteration : 0 
 , Train Loss : 29.719743728637695 
, Test Acc : 35.07616424560547 





Iteration : 10 
 , Train Loss : 9.375974655151367 
, Test Acc : 76.63624572753906 





Iteration : 20 
 , Train Loss : 2.310260534286499 
, Test Acc : 87.79728698730469 





Iteration : 30 
 , Train Loss : 1.314182996749878 
, Test Acc : 87.55415344238281 





Iteration : 40 
 , Train Loss : 0.6634061336517334 
, Test Acc : 90.52245330810547 



100%|██████████| 50/50 [00:16<00:00,  2.98it/s]
 47%|████▋     | 15/32 [04:06<04:40, 16.51s/it]

Iteration : 49 
 , Train Loss : 0.3527931272983551 
, Test Acc : 90.73101043701172 





Iteration : 0 
 , Train Loss : 30.458248138427734 
, Test Acc : 30.936002731323242 





Iteration : 10 
 , Train Loss : 14.334659576416016 
, Test Acc : 70.54410552978516 





Iteration : 20 
 , Train Loss : 3.3295066356658936 
, Test Acc : 86.03141021728516 





Iteration : 30 
 , Train Loss : 1.7330418825149536 
, Test Acc : 90.65265655517578 





Iteration : 40 
 , Train Loss : 0.6370010375976562 
, Test Acc : 92.77458953857422 



100%|██████████| 50/50 [00:16<00:00,  3.00it/s]
 50%|█████     | 16/32 [04:23<04:24, 16.56s/it]

Iteration : 49 
 , Train Loss : 0.5972536206245422 
, Test Acc : 94.1590576171875 





Iteration : 0 
 , Train Loss : 29.472545623779297 
, Test Acc : 22.3313045501709 





Iteration : 10 
 , Train Loss : 14.859447479248047 
, Test Acc : 71.46766662597656 





Iteration : 20 
 , Train Loss : 5.275289058685303 
, Test Acc : 84.33870697021484 





Iteration : 30 
 , Train Loss : 1.5288817882537842 
, Test Acc : 90.72179412841797 





Iteration : 40 
 , Train Loss : 1.0096545219421387 
, Test Acc : 92.49688720703125 



100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
 53%|█████▎    | 17/32 [04:39<04:08, 16.56s/it]

Iteration : 49 
 , Train Loss : 0.9474054574966431 
, Test Acc : 92.72273254394531 





Iteration : 0 
 , Train Loss : 35.2782096862793 
, Test Acc : 48.45478439331055 





Iteration : 10 
 , Train Loss : 10.125846862792969 
, Test Acc : 77.78968811035156 





Iteration : 20 
 , Train Loss : 2.491344928741455 
, Test Acc : 86.8074722290039 





Iteration : 30 
 , Train Loss : 1.3103411197662354 
, Test Acc : 90.13585662841797 





Iteration : 40 
 , Train Loss : 1.0006017684936523 
, Test Acc : 92.37071990966797 



100%|██████████| 50/50 [00:16<00:00,  2.98it/s]
 56%|█████▋    | 18/32 [04:56<03:52, 16.63s/it]

Iteration : 49 
 , Train Loss : 0.4722249507904053 
, Test Acc : 93.069580078125 





Iteration : 0 
 , Train Loss : 37.622249603271484 
, Test Acc : 37.667083740234375 





Iteration : 10 
 , Train Loss : 11.02817440032959 
, Test Acc : 70.59481048583984 





Iteration : 20 
 , Train Loss : 3.358163833618164 
, Test Acc : 85.44547271728516 





Iteration : 30 
 , Train Loss : 1.3385587930679321 
, Test Acc : 90.9781723022461 





Iteration : 40 
 , Train Loss : 1.1067789793014526 
, Test Acc : 92.201904296875 



100%|██████████| 50/50 [00:16<00:00,  2.98it/s]
 59%|█████▉    | 19/32 [05:13<03:36, 16.69s/it]

Iteration : 49 
 , Train Loss : 0.35387322306632996 
, Test Acc : 92.34939575195312 





Iteration : 0 
 , Train Loss : 22.799325942993164 
, Test Acc : 37.53687286376953 





Iteration : 10 
 , Train Loss : 15.100811004638672 
, Test Acc : 70.79127502441406 





Iteration : 20 
 , Train Loss : 2.9901037216186523 
, Test Acc : 87.9753189086914 





Iteration : 30 
 , Train Loss : 1.5035806894302368 
, Test Acc : 90.14507293701172 





Iteration : 40 
 , Train Loss : 1.0431963205337524 
, Test Acc : 91.41661834716797 



100%|██████████| 50/50 [00:17<00:00,  2.88it/s]
 62%|██████▎   | 20/32 [05:30<03:22, 16.89s/it]

Iteration : 49 
 , Train Loss : 0.3644329607486725 
, Test Acc : 92.24972534179688 





Iteration : 0 
 , Train Loss : 35.563716888427734 
, Test Acc : 29.76931381225586 





Iteration : 10 
 , Train Loss : 12.439555168151855 
, Test Acc : 75.35951232910156 





Iteration : 20 
 , Train Loss : 1.8136310577392578 
, Test Acc : 86.56952667236328 





Iteration : 30 
 , Train Loss : 1.7498838901519775 
, Test Acc : 89.16447448730469 





Iteration : 40 
 , Train Loss : 0.6068224906921387 
, Test Acc : 91.138916015625 



100%|██████████| 50/50 [00:17<00:00,  2.89it/s]
 66%|██████▌   | 21/32 [05:47<03:07, 17.01s/it]

Iteration : 49 
 , Train Loss : 0.31698018312454224 
, Test Acc : 91.48172760009766 





Iteration : 0 
 , Train Loss : 29.63682746887207 
, Test Acc : 44.56178283691406 





Iteration : 10 
 , Train Loss : 11.873008728027344 
, Test Acc : 78.60147094726562 





Iteration : 20 
 , Train Loss : 3.099485397338867 
, Test Acc : 87.38016510009766 





Iteration : 30 
 , Train Loss : 2.334475517272949 
, Test Acc : 89.31600189208984 





Iteration : 40 
 , Train Loss : 1.058392882347107 
, Test Acc : 91.3901138305664 



100%|██████████| 50/50 [00:16<00:00,  2.95it/s]
 69%|██████▉   | 22/32 [06:04<02:49, 16.99s/it]

Iteration : 49 
 , Train Loss : 0.33377841114997864 
, Test Acc : 92.7573013305664 





Iteration : 0 
 , Train Loss : 30.928892135620117 
, Test Acc : 26.02495765686035 





Iteration : 10 
 , Train Loss : 16.404281616210938 
, Test Acc : 68.07418823242188 





Iteration : 20 
 , Train Loss : 5.18757438659668 
, Test Acc : 81.7783203125 





Iteration : 30 
 , Train Loss : 1.0388504266738892 
, Test Acc : 89.99757385253906 





Iteration : 40 
 , Train Loss : 0.9917730689048767 
, Test Acc : 91.20806121826172 



100%|██████████| 50/50 [00:16<00:00,  2.98it/s]
 72%|███████▏  | 23/32 [06:21<02:32, 16.93s/it]

Iteration : 49 
 , Train Loss : 0.42748528718948364 
, Test Acc : 92.87023162841797 





Iteration : 0 
 , Train Loss : 34.327938079833984 
, Test Acc : 40.48384475708008 





Iteration : 10 
 , Train Loss : 10.93166446685791 
, Test Acc : 72.10948944091797 





Iteration : 20 
 , Train Loss : 3.3857204914093018 
, Test Acc : 85.99684143066406 





Iteration : 30 
 , Train Loss : 1.4827359914779663 
, Test Acc : 89.03427124023438 





Iteration : 40 
 , Train Loss : 0.6390799880027771 
, Test Acc : 92.60578155517578 



100%|██████████| 50/50 [00:16<00:00,  2.96it/s]
 75%|███████▌  | 24/32 [06:38<02:15, 16.92s/it]

Iteration : 49 
 , Train Loss : 0.5095298886299133 
, Test Acc : 92.01580810546875 





Iteration : 0 
 , Train Loss : 29.05426597595215 
, Test Acc : 38.856239318847656 





Iteration : 10 
 , Train Loss : 11.32693862915039 
, Test Acc : 73.06877136230469 





Iteration : 20 
 , Train Loss : 5.192742347717285 
, Test Acc : 86.2002182006836 





Iteration : 30 
 , Train Loss : 1.6473535299301147 
, Test Acc : 91.12969970703125 





Iteration : 40 
 , Train Loss : 0.7875978946685791 
, Test Acc : 92.33615112304688 



100%|██████████| 50/50 [00:16<00:00,  2.96it/s]
 78%|███████▊  | 25/32 [06:55<01:58, 16.91s/it]

Iteration : 49 
 , Train Loss : 0.3616253137588501 
, Test Acc : 92.92208862304688 





Iteration : 0 
 , Train Loss : 29.764774322509766 
, Test Acc : 31.17106819152832 





Iteration : 10 
 , Train Loss : 7.594303131103516 
, Test Acc : 72.28348541259766 





Iteration : 20 
 , Train Loss : 3.238492250442505 
, Test Acc : 84.06100463867188 





Iteration : 30 
 , Train Loss : 1.4408913850784302 
, Test Acc : 90.85199737548828 





Iteration : 40 
 , Train Loss : 0.5904083251953125 
, Test Acc : 93.02175903320312 



100%|██████████| 50/50 [00:16<00:00,  2.94it/s]
 81%|████████▏ | 26/32 [07:12<01:41, 16.93s/it]

Iteration : 49 
 , Train Loss : 0.2629304826259613 
, Test Acc : 93.90267944335938 





Iteration : 0 
 , Train Loss : 35.539031982421875 
, Test Acc : 41.949554443359375 





Iteration : 10 
 , Train Loss : 9.805353164672852 
, Test Acc : 75.61589813232422 





Iteration : 20 
 , Train Loss : 3.5133166313171387 
, Test Acc : 87.40666198730469 





Iteration : 30 
 , Train Loss : 1.5334112644195557 
, Test Acc : 89.72391510009766 





Iteration : 40 
 , Train Loss : 1.4822158813476562 
, Test Acc : 90.89579010009766 



100%|██████████| 50/50 [00:16<00:00,  2.94it/s]
 84%|████████▍ | 27/32 [07:29<01:24, 16.95s/it]

Iteration : 49 
 , Train Loss : 0.6228773593902588 
, Test Acc : 91.2558822631836 





Iteration : 0 
 , Train Loss : 41.644283294677734 
, Test Acc : 22.32611846923828 





Iteration : 10 
 , Train Loss : 22.498414993286133 
, Test Acc : 66.5208969116211 





Iteration : 20 
 , Train Loss : 4.464673042297363 
, Test Acc : 84.82093811035156 





Iteration : 30 
 , Train Loss : 2.212244749069214 
, Test Acc : 90.26203155517578 





Iteration : 40 
 , Train Loss : 1.0185352563858032 
, Test Acc : 91.22130584716797 



100%|██████████| 50/50 [00:17<00:00,  2.90it/s]
 88%|████████▊ | 28/32 [07:46<01:08, 17.05s/it]

Iteration : 49 
 , Train Loss : 0.6120910048484802 
, Test Acc : 91.7594223022461 





Iteration : 0 
 , Train Loss : 36.799983978271484 
, Test Acc : 28.65850830078125 





Iteration : 10 
 , Train Loss : 8.884716987609863 
, Test Acc : 64.05213165283203 





Iteration : 20 
 , Train Loss : 3.39878249168396 
, Test Acc : 86.43009948730469 





Iteration : 30 
 , Train Loss : 1.5886340141296387 
, Test Acc : 90.834716796875 





Iteration : 40 
 , Train Loss : 1.0456222295761108 
, Test Acc : 92.33210754394531 



100%|██████████| 50/50 [00:17<00:00,  2.90it/s]
 91%|█████████ | 29/32 [08:03<00:51, 17.11s/it]

Iteration : 49 
 , Train Loss : 0.6050530076026917 
, Test Acc : 93.264892578125 





Iteration : 0 
 , Train Loss : 26.50301742553711 
, Test Acc : 41.06056213378906 





Iteration : 10 
 , Train Loss : 17.844388961791992 
, Test Acc : 74.18878936767578 





Iteration : 20 
 , Train Loss : 4.717982769012451 
, Test Acc : 86.52574157714844 





Iteration : 30 
 , Train Loss : 1.4859853982925415 
, Test Acc : 90.09321594238281 





Iteration : 40 
 , Train Loss : 1.1841453313827515 
, Test Acc : 93.5385513305664 



100%|██████████| 50/50 [00:17<00:00,  2.91it/s]
 94%|█████████▍| 30/32 [08:21<00:34, 17.13s/it]

Iteration : 49 
 , Train Loss : 0.7361757755279541 
, Test Acc : 93.83354187011719 





Iteration : 0 
 , Train Loss : 25.07516098022461 
, Test Acc : 35.128021240234375 





Iteration : 10 
 , Train Loss : 12.367348670959473 
, Test Acc : 81.64810943603516 





Iteration : 20 
 , Train Loss : 3.947296142578125 
, Test Acc : 86.7994155883789 





Iteration : 30 
 , Train Loss : 1.348787546157837 
, Test Acc : 91.14295959472656 





Iteration : 40 
 , Train Loss : 1.295526146888733 
, Test Acc : 92.397216796875 



100%|██████████| 50/50 [00:17<00:00,  2.88it/s]
 97%|█████████▋| 31/32 [08:38<00:17, 17.20s/it]

Iteration : 49 
 , Train Loss : 0.5586777329444885 
, Test Acc : 93.264892578125 





Iteration : 0 
 , Train Loss : 26.973094940185547 
, Test Acc : 34.42108917236328 





Iteration : 10 
 , Train Loss : 14.23310375213623 
, Test Acc : 73.28136444091797 





Iteration : 20 
 , Train Loss : 4.280537128448486 
, Test Acc : 85.22884368896484 





Iteration : 30 
 , Train Loss : 1.8137036561965942 
, Test Acc : 89.28143310546875 





Iteration : 40 
 , Train Loss : 0.9283969402313232 
, Test Acc : 90.22746276855469 



100%|██████████| 50/50 [00:17<00:00,  2.86it/s]
100%|██████████| 32/32 [08:56<00:00, 16.75s/it]

Iteration : 49 
 , Train Loss : 0.4980047941207886 
, Test Acc : 90.9954605102539 






In [163]:
import networkx as nx
G = nx.Graph()
G.add_nodes_from(range(config["num_clients"]))
import itertools
def model_weights_diff(w_1, w_2):
    norm_sq = 0
    assert w_1.keys() == w_2.keys(), "Model weights have different keys"
    for key in w_1.keys():
        norm_sq  += (w_1[key] - w_2[key]).norm()**2
    return np.sqrt(norm_sq)
wt = client_trainers[0].model.state_dict()
# thresh = 0
# for key in wt.keys():
#     thresh += wt[key].norm()**2
# print(torch.sqrt(thresh))
# thresh = 37.68

all_pairs = list(itertools.combinations(range(config["num_clients"]),2))
arr = []
for pair in all_pairs:
    w_1  = client_trainers[pair[0]].model.state_dict()
    w_2 = client_trainers[pair[1]].model.state_dict()
    norm_diff = model_weights_diff(w_1, w_2)
    arr.append(norm_diff)
#thresh = torch.mean(torch.tensor(arr))
thresh = arr[torch.tensor(arr).argsort()[int(0.36*len(arr))-1]]
for i in range(len(all_pairs)):
    if arr[i] < thresh:
        G.add_edge(all_pairs[i][0], all_pairs[i][1])
G = G.to_undirected()
#cliques = list(nx.algorithms.clique.enumerate_all_cliques(G))


In [106]:
arr[torch.tensor(arr).argsort()[int(0.4*len(arr))]]

tensor(38.1052)

In [164]:
for node in G[0]:
    print(node)

1
4
9
17
18
20
22
23
28


In [165]:
adj = nx.convert_matrix.to_numpy_array(G)

In [108]:
adj.sum(axis=1)

array([ 9., 31., 11.,  8., 14.,  2., 11., 10.,  4., 26., 10.,  3.,  2.,
        7., 10., 20.,  8., 26., 17., 14., 25., 14., 24., 25.,  9.,  1.,
       12.,  4., 22.,  5.,  6.,  4.])

In [166]:
clustering = []
def correlation_clustering(G):
    global clustering
    if len(G.nodes) == 0:
        return
    else:
        cluster = []
        new_cluster_pivot = random.sample(G.nodes,1)[0]
        cluster.append(new_cluster_pivot)
        neighbors = G[new_cluster_pivot].copy()
        for node in neighbors:
            cluster.append(node)
            G.remove_node(node)
        G.remove_node(new_cluster_pivot)
        clustering.append(cluster)
        correlation_clustering(G)
correlation_clustering(G)

since Python 3.9 and will be removed in a subsequent version.
  new_cluster_pivot = random.sample(G.nodes,1)[0]


In [167]:
clustering

[[24, 1, 9, 15, 18, 20, 23, 28],
 [26, 6, 17, 21, 22],
 [25],
 [30],
 [11],
 [0, 4],
 [19, 2, 7, 14, 29],
 [5],
 [13],
 [31],
 [27],
 [10],
 [3],
 [12],
 [16],
 [8]]

In [169]:
#config["t"] = 7
#t = config["t"]
clusters = [cluster  for cluster in clustering]
cluster_map = {i: clusters[i] for i in range(len(clusters))}
beta = 0.3

In [174]:
class ClusterTrainer(BaseTrainer):
    def __init__(self,  config, save_dir,cluster_id):
        super(ClusterTrainer, self).__init__(config, save_dir)
        self.cluster_id = cluster_id
    
    def train(self, client_data_list):
        num_clients = len(client_data_list)

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        for iteration in tqdm(range(self.config["iterations"])):
            trmean_buffer = {}
            for idx, param in self.model.named_parameters():
                trmean_buffer[idx] = []
            train_loss = 0
            for client in client_data_list:
                optimizer.zero_grad()
                (X,Y) = client.sample_batch()
                X = X.to(config["device"])
                Y = Y.to(config["device"])
                loss_func = nn.CrossEntropyLoss()
                out = self.model(X)
                loss = loss_func(out,Y)
                loss.backward()
                train_loss += loss.detach().cpu().numpy().item()
                with torch.no_grad():
                    for idx, param in self.model.named_parameters():
                        trmean_buffer[idx].append(param.grad.clone())
            train_loss = train_loss/num_clients
            optimizer.zero_grad()
            for idx, param in self.model.named_parameters():
                sorted, _  = torch.sort(torch.stack(trmean_buffer[idx], dim=0), dim=0)
                new_grad = sorted[int(beta*num_clients):int((1-beta)*num_clients),...].mean(dim=0)
                param.grad = new_grad
                trmean_buffer[idx] = []
            optimizer.step()
            
            train_loss_list.append(train_loss)
            test_acc = 0
            for client_data in client_data_list:
                test_acc += calc_acc(self.model, self.device, client_data, train=False)
            test_acc = test_acc/num_clients
            test_acc_list.append(test_acc)
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list)
            if iteration % self.config["print_freq"] == 0:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))
                
        self.model.eval()
        self.model.cpu()


    def test(self, client_data_list):
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        self.model.cpu()
        return test_acc


In [175]:
%%time 
config["refine_steps"] = 3
for refine_step in tqdm(range(config["refine_steps"])):
    beta = 0.3
    cluster_trainers = []
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        cluster_trainer = ClusterTrainer(config, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "cluster_{}".format(cluster_id)), cluster_id)
        cluster_trainer.train(cluster_clients)
        cluster_trainers.append(cluster_trainer)
    with open(os.path.join(config["results_dir"],"refine_{}".format(refine_step), "cluster_maps.pkl"), 'wb') as handle:
        pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    cluster_map_recluster = {}
    for key in cluster_map.keys():
        cluster_map_recluster[key] = []

    for i in tqdm(range(config["num_clients"])):
        w_node = client_trainers[i].model.state_dict()
        norm_diff = np.infty
        new_cluster_id = 0
        for cluster_id in cluster_map.keys():
            w_cluster = cluster_trainers[cluster_id].model.state_dict()
            curr_norm_diff = model_weights_diff(w_node, w_cluster)
            if norm_diff > curr_norm_diff:
                new_cluster_id = cluster_id
                norm_diff = curr_norm_diff
        
        cluster_map_recluster[new_cluster_id].append(i)
    cluster_map = cluster_map_recluster


    G = nx.Graph()
    G.add_nodes_from(cluster_map.keys())

    all_pairs = list(itertools.combinations(cluster_map.keys(),2))
    for pair in tqdm(all_pairs):
        w_1  = cluster_trainers[pair[0]].model.state_dict()
        w_2 = cluster_trainers[pair[1]].model.state_dict()
        norm_diff = model_weights_diff(w_1, w_2)
        if norm_diff < thresh:
            G.add_edge(pair[0], pair[1])
    G = G.to_undirected()
    clustering = []        
    correlation_clustering(G)
    merge_clusters = [cluster  for cluster in clustering]
    #merge_cluster_map = {i: clusters[i] for i in range(len(clusters))}
    #clusters = list(nx.algorithms.clique.enumerate_all_cliques(G))
    cluster_map_new = {}
    for i in range(len(clusters)):
        cluster_map_new[i] = []
        for j in merge_clusters[i]:
            cluster_map_new[i] += cluster_map[j]
    cluster_map = cluster_map_new


  0%|          | 0/3 [00:00<?, ?it/s]
[A
[A

Iteration : 0 
 , Train Loss : 25.816818237304688 
, Test Acc : 18.80019760131836 




[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Iteration : 10 
 , Train Loss : 32.171459436416626 
, Test Acc : 44.74650573730469 




[A
[A
[A
[A
[A
[A

In [83]:
with open(os.path.join(config["results_dir"],"refine_{}".format(refine_step), "cluster_maps.pkl"), 'wb') as handle:
    pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
cluster_map_recluster = {}
for key in cluster_map.keys():
    cluster_map_recluster[key] = []

for i in tqdm(range(config["num_clients"])):
    w_node = client_trainers[i].model.state_dict()
    norm_diff = np.infty
    new_cluster_id = 0
    for cluster_id in cluster_map.keys():
        w_cluster = cluster_trainers[cluster_id].model.state_dict()
        curr_norm_diff = model_weights_diff(w_node, w_cluster)
        if norm_diff > curr_norm_diff:
            new_cluster_id = cluster_id
            norm_diff = curr_norm_diff
    
    cluster_map_recluster[new_cluster_id].append(i)
cluster_map = cluster_map_recluster


G = nx.Graph()
G.add_nodes_from(cluster_map.keys())

all_pairs = list(itertools.combinations(cluster_map.keys(),2))
for pair in tqdm(all_pairs):
    w_1  = cluster_trainers[pair[0]].model.state_dict()
    w_2 = cluster_trainers[pair[1]].model.state_dict()
    norm_diff = model_weights_diff(w_1, w_2)
    if norm_diff < thresh:
        G.add_edge(pair[0], pair[1])
        
clusters = list(nx.algorithms.clique.enumerate_all_cliques(G))


100%|██████████| 32/32 [00:05<00:00,  5.41it/s]
100%|██████████| 12246/12246 [00:14<00:00, 820.48it/s]


In [85]:
cluster_map_new = {}
for i in range(len(clusters)):
    cluster_map_new[i] = []
    for j in clusters[i]:
        cluster_map_new[i] += cluster_map[j]
cluster_map = cluster_map_new

In [88]:
len(clusters)

668