In [1]:
# IPython autoreload: mode 2 re-imports edited modules automatically so code changes are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:




import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from abc import ABC
from tqdm import tqdm
import torchvision
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import json
from collections import OrderedDict


In [3]:
# Experiment configuration dict, seeded first so every later random draw is repeatable.
config = {"seed": 46}
seed = config["seed"]
# Only affects interpreters started after this point; the running kernel's hash seed is fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed)
# Seed torch (CPU + every CUDA device), then numpy and Python's `random`, in that order.
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    seed_fn(seed)


In [4]:
# Dataset location and per-seed results directory.
config.update({
    "participation_ratio": 0.05 / 6,
    "dataset": "femnist",
    "dataset_dir": "/base_vol/femnist/data",
})
# NOTE(review): no separator between the two parts, so this yields
# "./experiments/results/femnistifca/seed_<seed>" — confirm that is intended.
config["results_dir"] = os.path.join(
    "./experiments/results",
    config["dataset"] + "ifca",
    "seed_{}".format(seed),
)
os.makedirs(config["results_dir"], exist_ok=True)


In [5]:
from collections import defaultdict
def read_dir(data_dir):
    """Load every ``.json`` shard in ``data_dir`` and merge their user data.

    Each shard is expected to contain a ``'user_data'`` mapping and,
    optionally, a ``'hierarchies'`` list.

    Shards are processed in sorted filename order. ``os.listdir`` order is
    arbitrary, and without sorting the order of ``groups``, and which shard
    wins when a user id appears twice, would vary between runs and machines.

    Args:
        data_dir: directory containing the ``.json`` shards.

    Returns:
        (clients, groups, data):
            clients: sorted list of user ids present in ``data``.
            groups: concatenation of every shard's ``'hierarchies'``
                (empty list if none).
            data: ``defaultdict`` mapping user id -> that user's record.
                Unknown ids yield ``None``.
    """
    groups = []
    data = defaultdict(lambda: None)

    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for file_name in json_files:
        with open(os.path.join(data_dir, file_name), 'r') as inf:
            cdata = json.load(inf)
        groups.extend(cdata.get('hierarchies', []))
        data.update(cdata['user_data'])

    # The client list is derived from the merged user_data keys, not from each shard's 'users'.
    clients = sorted(data.keys())
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    '''Parse a train/test split stored as directories of .json shards.

    Both directories are read with ``read_dir``. The split must be
    consistent: the sorted client ids and the group lists of the two
    directories have to be equal, otherwise an ``AssertionError`` is raised.

    Return:
        clients: list of client ids (from the train split)
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data keyed by client id
        test_data: dictionary of test data keyed by client id
    '''
    train_split = read_dir(train_data_dir)
    test_split = read_dir(test_data_dir)
    clients, groups, train_data = train_split
    test_clients, test_groups, test_data = test_split

    assert clients == test_clients
    assert groups == test_groups

    return clients, groups, train_data, test_data


In [6]:
# Load both splits; only the client ids and the per-client data dicts are kept.
femnist_train_dir = os.path.join(config["dataset_dir"], "train")
femnist_test_dir = os.path.join(config["dataset_dir"], "test")
config["total_clients"], _, train_data, test_data = read_data(femnist_train_dir, femnist_test_dir)


In [7]:
class ClientDataset(Dataset):
    """Per-client dataset of flattened images and their labels.

    Args:
        data: a ``(features, labels)`` pair. ``features[idx]`` is reshaped to
            ``image_shape`` and ``labels[idx]`` is returned unchanged.
        transforms: optional callable applied to the reshaped sample. It is
            expected to return a ``torch.Tensor`` (e.g.
            ``torchvision.transforms.ToTensor()``).
        image_shape: shape each flattened sample is reshaped to
            (default ``(28, 28)``).
    """

    def __init__(self, data, transforms=None, image_shape=(28, 28)):
        super(ClientDataset, self).__init__()
        self.data = data[0]
        self.labels = data[1]
        self.transforms = transforms
        self.image_shape = tuple(image_shape)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(float32 sample tensor, label)`` for position ``idx``."""
        sample = self.data[idx].reshape(*self.image_shape)
        if self.transforms is not None:
            sample = self.transforms(sample)
        else:
            # Without a transform the sample is still a numpy array, which has
            # no ``.float()``; convert it so the untransformed path works too.
            sample = torch.as_tensor(sample)
        return (sample.float(), self.labels[idx])

class Client():
    """One client's train/test data wrapped in DataLoaders, with a resumable batch sampler.

    Args:
        train_data, test_data: ``(features, labels)`` pairs passed to ``ClientDataset``.
        client_id: identifier stored on the instance and used in ``save_dir``.
        train_transforms, test_transforms: per-sample transforms for each split.
        train_batch_size, test_batch_size: DataLoader batch sizes.
        save_dir: base directory. The client's own directory is
            ``<save_dir>/init/client_<client_id>``.
        num_workers: DataLoader worker processes per loader (default 1).
            Both iterators are created in ``__init__``, so with
            ``num_workers > 0`` every Client keeps worker processes alive.
            Pass 0 to avoid that when there are many clients.
    """
    def __init__(self, train_data, test_data, client_id, train_transforms, test_transforms, train_batch_size, test_batch_size, save_dir, num_workers=1):
        self.trainset = ClientDataset(train_data, train_transforms)
        self.testset = ClientDataset(test_data, test_transforms)
        self.trainloader = DataLoader(self.trainset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
        self.testloader = DataLoader(self.testset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
        # Created eagerly, so the RNG draws made when an iterator starts happen in a fixed order.
        self.train_iterator = iter(self.trainloader)
        self.test_iterator = iter(self.testloader)
        self.client_id = client_id
        self.save_dir = os.path.join(save_dir, "init", "client_{}".format(client_id))

    def sample_batch(self, train=True):
        """Return the next ``(data, labels)`` batch from the train (or test) loader.

        When the current pass over the loader is exhausted, a new iterator is
        started (reshuffled for the train loader) and its first batch is returned.
        """
        try:
            (data, labels) = next(self.train_iterator if train else self.test_iterator)
        except StopIteration:
            if train:
                self.train_iterator = iter(self.trainloader)
                iterator = self.train_iterator
            else:
                self.test_iterator = iter(self.testloader)
                iterator = self.test_iterator
            (data, labels) = next(iterator)
        return (data, labels)


In [8]:
## Generate new clients (or reuse the selection previously saved for this dataset/seed)
selected_clients_path = os.path.join("./experiments/results", config["dataset"], "seed_{}".format(config["seed"]), "selected_clients.json")
if os.path.exists(selected_clients_path):
    with open(selected_clients_path, "r") as f:
        config["selected_clients"] = json.load(f)
else:
    # "num_clients" is otherwise first assigned in a later cell; fall back to
    # that same formula so this cell also works on a fresh kernel.
    config.setdefault("num_clients", int(len(config["total_clients"]) / 6))
    config["selected_clients"] = random.sample(config["total_clients"], config["num_clients"])
    # This shared path is outside results_dir, so its directory may not exist yet.
    os.makedirs(os.path.dirname(selected_clients_path), exist_ok=True)
    with open(selected_clients_path, "w") as f:
        json.dump(config["selected_clients"], f)
# Keep a copy of the (unshuffled) selection next to this run's results.
with open(os.path.join(config["results_dir"], "selected_clients.json"), "w") as f:
    json.dump(config["selected_clients"], f)
random.shuffle(config["selected_clients"])

In [9]:
# Loader batch sizes and the number of clients (one sixth of all users).
config.update(
    train_batch=64,
    test_batch=512,
    num_clients=len(config["total_clients"]) // 6,
)


In [10]:
def _build_client(client_id):
    """Create a ``Client`` for one user id from the loaded train/test dicts."""
    train_record = train_data[client_id]
    test_record = test_data[client_id]
    return Client(
        (np.array(train_record['x']), np.array(train_record['y'])),
        (np.array(test_record['x']), np.array(test_record['y'])),
        client_id,
        train_transforms=torchvision.transforms.ToTensor(),
        test_transforms=torchvision.transforms.ToTensor(),
        train_batch_size=config["train_batch"],
        test_batch_size=config["test_batch"],
        save_dir=config["results_dir"],
    )


# One Client per selected user, in the (shuffled) selection order.
client_loaders = [_build_client(client_id) for client_id in config["selected_clients"]]


In [11]:
class SimpleCNN(torch.nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    The layer sizes assume 1-channel 28x28 inputs: the two 2x2 poolings
    reduce 28x28 to 7x7, giving ``64*7*7`` features into ``fc1``.

    Args:
        h1: width of the hidden fully connected layer (``fc1`` output /
            ``fc2`` input). Default 2048.
        num_classes: number of output logits. Default 62.

    ``set_weights`` loads a fixed checkpoint with a strict ``load_state_dict``,
    so it needs the default sizes unless the checkpoint matches.
    """

    def __init__(self, h1=2048, num_classes=62):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(5, 5), padding="same")
        self.pool1 = torch.nn.MaxPool2d((2, 2), stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(5, 5), padding="same")
        self.pool2 = torch.nn.MaxPool2d((2, 2), stride=2)
        self.fc1 = torch.nn.Linear(64 * 7 * 7, h1)
        self.fc2 = torch.nn.Linear(h1, num_classes)

    def forward(self, x):
        """Map a ``(N, 1, 28, 28)`` batch to ``(N, num_classes)`` logits."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [12]:
def set_weights(model, weights_path='/base_vol/model_wt_dict.pt'):
    """Copy pretrained weights from ``weights_path`` into ``model``, then freeze its layers via ``freeze_layers``.

    The checkpoint must be a dict with keys ``conv2d/kernel``, ``conv2d/bias``,
    ``conv2d_1/kernel``, ``conv2d_1/bias``, ``dense/kernel``, ``dense/bias``,
    ``dense_1/kernel`` and ``dense_1/bias``. Dense kernels are transposed and
    conv kernels permuted with ``(3, 2, 0, 1)`` before loading. Presumably
    the source layout is TensorFlow-style ((in, out) and (H, W, in, out));
    confirm against the export script.

    Args:
        model: module with ``conv1``, ``conv2``, ``fc1`` and ``fc2`` layers
            whose shapes match the checkpoint (the load is strict).
        weights_path: checkpoint to read (default ``/base_vol/model_wt_dict.pt``).

    Returns:
        The result of ``freeze_layers(model)``.
    """
    # map_location keeps this working on CPU-only hosts even if the checkpoint
    # holds CUDA tensors; load_state_dict copies values onto the model's device.
    model_wt = torch.load(weights_path, map_location="cpu")
    new_wts = OrderedDict()
    # Fully connected layers: source kernels are transposed to torch's (out, in).
    new_wts['fc2.weight'] = torch.Tensor(model_wt["dense_1/kernel"]).t()
    new_wts['fc2.bias'] = torch.Tensor(model_wt["dense_1/bias"])
    new_wts['fc1.weight'] = torch.Tensor(model_wt["dense/kernel"]).t()
    new_wts['fc1.bias'] = torch.Tensor(model_wt["dense/bias"])
    # Conv layers: permute to torch's (out, in, H, W).
    new_wts["conv1.weight"] = torch.Tensor(model_wt["conv2d/kernel"]).permute(3, 2, 0, 1)
    new_wts["conv2.weight"] = torch.Tensor(model_wt["conv2d_1/kernel"]).permute(3, 2, 0, 1)
    new_wts["conv1.bias"] = torch.Tensor(model_wt["conv2d/bias"])
    new_wts["conv2.bias"] = torch.Tensor(model_wt["conv2d_1/bias"])
    model.load_state_dict(new_wts)
    return freeze_layers(model)
def freeze_layers(model):
    """Freeze ``conv1``/``conv2`` and make ``fc1``/``fc2`` trainable.

    Only the ``weight`` and ``bias`` parameters of those four submodules are
    touched. The model is modified in place and returned for chaining.
    """
    trainable_by_layer = {"conv1": False, "conv2": False, "fc1": True, "fc2": True}
    for layer_name, trainable in trainable_by_layer.items():
        layer = getattr(model, layer_name)
        layer.weight.requires_grad = trainable
        layer.bias.requires_grad = trainable
    return model

In [39]:
def calc_acc(model, device, client_data, train):
    """Percentage of correct argmax predictions on one client's train or test loader.

    Puts ``model`` in eval mode and moves it to ``device``. Returns a 0-d
    tensor in ``[0, 100]``.
    """
    data_loader = client_data.trainloader if train else client_data.testloader
    model.eval()
    model.to(device)
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            predictions = model(inputs.to(device)).argmax(axis=1).detach().cpu()
            correct = correct + (predictions == targets).float().sum()
            seen += targets.shape[0]
    return correct / seen * 100.0


In [15]:
class BaseTrainer(ABC):
    """Shared state and checkpoint helpers for the trainers.

    The model is built as ``MODEL_LIST[config["model"]](**config["model_params"])``
    and passed through ``set_weights``. ``train``/``test`` must be overridden.
    Checkpoints and metrics are written under ``save_dir``.
    """
    def __init__(self, config, save_dir):
        super(BaseTrainer, self).__init__()
        self.model = set_weights(MODEL_LIST[config["model"]](**config["model_params"]))
        self.save_dir = save_dir
        self.device = config["device"]
        self.loss_func = LOSSES[config["loss_func"]]
        self.config = config

    def train(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def load_model_weights(self):
        """Load ``<save_dir>/model.pth`` into ``self.model`` if present; otherwise print a notice."""
        model_path = os.path.join(self.save_dir, "model.pth")
        if os.path.exists(model_path):
            # map_location lets checkpoints saved from a GPU model load anywhere;
            # load_state_dict copies the values onto the model's current device.
            self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        else:
            # The placeholder needs the path argument; without it format() raises IndexError.
            print("No model present at path : {}".format(model_path))

    def save_model_weights(self):
        """Write ``self.model``'s state dict to ``<save_dir>/model.pth``, creating the directory."""
        os.makedirs(self.save_dir, exist_ok=True)
        model_path = os.path.join(self.save_dir, "model.pth")
        torch.save(self.model.state_dict(), model_path)

    def save_metrics(self, train_loss, test_acc, iteration):
        """Save the loss/accuracy histories to ``<save_dir>/metrics_<iteration>.pkl`` via ``torch.save``."""
        os.makedirs(self.save_dir, exist_ok=True)
        torch.save({"train_loss": train_loss, "test_acc": test_acc}, os.path.join(self.save_dir, "metrics_{}.pkl".format(iteration)))

class ClusterTrainer(BaseTrainer):
    """Trains one cluster's model from the gradients of its member clients.

    Each iteration samples one batch per client, computes that client's
    gradient separately, and combines the per-client gradients coordinate-wise
    before a single optimizer step. Combining means sorting along the client
    axis, then averaging rows ``[start_idx, end_idx)``; with the current
    bounds that is every row, i.e. a plain mean.
    """
    def __init__(self, config, save_dir, cluster_id):
        super(ClusterTrainer, self).__init__(config, save_dir)
        self.cluster_id = cluster_id

    def _aggregate_gradients(self, grad_buffer, num_clients):
        """Set each trainable ``param.grad`` to the coordinate-wise aggregate of its buffered per-client gradients."""
        start_idx = 0
        end_idx = num_clients
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                sorted_grads, _ = torch.sort(torch.stack(grad_buffer[name], dim=0), dim=0)
                param.grad = sorted_grads[start_idx:end_idx, ...].mean(dim=0)

    def _mean_test_acc(self, client_data_list):
        """Mean of ``calc_acc`` over each client's test loader."""
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        return test_acc / len(client_data_list)

    def train(self, client_data_list):
        """Run ``config["iterations"]`` aggregated steps on ``client_data_list``.

        Saves weights/metrics every ``save_freq`` iterations and prints every
        ``print_freq``. The model ends in eval mode on the CPU. An empty client
        list is skipped with a notice, since there are no gradients to aggregate.
        """
        num_clients = len(client_data_list)
        if num_clients == 0:
            print("Cluster {}: no clients assigned, skipping training".format(self.cluster_id))
            return

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()

        # Optimizer (and e.g. Adam state) is recreated on every train() call.
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])

        for iteration in tqdm(range(self.config["iterations"])):
            grad_buffer = {name: [] for name, param in self.model.named_parameters() if param.requires_grad}
            train_loss = 0

            for client in client_data_list:
                # Each client's gradient is computed in isolation, then buffered.
                optimizer.zero_grad(set_to_none=True)
                (X, Y) = client.sample_batch()
                X = X.to(self.device)
                Y = Y.to(self.device)
                loss = self.loss_func(self.model(X), Y)
                loss.backward()
                train_loss += loss.detach().cpu().numpy().item()
                with torch.no_grad():
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            grad_buffer[name].append(param.grad.clone())
            train_loss = train_loss / num_clients
            optimizer.zero_grad()

            self._aggregate_gradients(grad_buffer, num_clients)
            optimizer.step()

            train_loss_list.append(train_loss)
            test_acc = self._mean_test_acc(client_data_list)
            test_acc_list.append(test_acc)
            # calc_acc switched the model to eval mode.
            self.model.train()
            is_last = iteration == self.config["iterations"] - 1
            if iteration % self.config["save_freq"] == 0 or is_last:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or is_last:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration, train_loss, test_acc))

        self.model.eval()
        self.model.cpu()

    def test(self, client_data_list):
        """Reload ``<save_dir>/model.pth`` and return the SUM of per-client test accuracies.

        NOTE(review): ``train`` reports the mean over clients, while this returns
        the sum; callers must divide by ``len(client_data_list)`` themselves.
        """
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        self.model.cpu()
        return test_acc


def avg_acc(model_wts, client_data_list, device=torch.device("cuda:0")):
    """Average several ``SimpleCNN`` state dicts and evaluate the averaged model.

    ``torch.float32`` entries are averaged element-wise over ``model_wts``;
    every other entry is taken from ``model_wts[0]``. The inputs are not
    modified. This matters because ``state_dict()`` tensors share storage with
    the live model, so accumulating into ``model_wts[0]`` in place would
    overwrite that model's weights.

    Args:
        model_wts: non-empty list of state dicts with identical keys.
        client_data_list: clients whose test loaders are evaluated.
        device: device used for evaluation (default ``cuda:0``).

    Returns:
        (mean of ``calc_acc`` over the clients, averaged state dict)

    Raises:
        ValueError: if ``model_wts`` is empty.
    """
    if len(model_wts) == 0:
        raise ValueError("avg_acc needs at least one state dict")
    # Clone so neither the caller's dict nor the model it came from is mutated.
    averaged = OrderedDict((key, value.clone()) for key, value in model_wts[0].items())
    for wt in model_wts[1:]:
        for key in averaged.keys():
            if averaged[key].dtype == torch.float32:
                averaged[key] += wt[key]
    for key in averaged.keys():
        if averaged[key].dtype == torch.float32:
            averaged[key] = averaged[key] / len(model_wts)
    model = SimpleCNN()
    model.load_state_dict(averaged)
    model.to(memory_format=torch.channels_last).to(device)
    test_acc = 0
    for client_data in client_data_list:
        test_acc += calc_acc(model, device, client_data, train=False)
    test_acc = test_acc / len(client_data_list)
    return test_acc, averaged



In [43]:
def init_cluster_map(num_clusters, client_list):
    """Round-robin initial assignment: client position ``i`` goes to cluster ``i % num_clusters``.

    Returns a dict mapping every cluster id in ``range(num_clusters)`` to the
    ascending list of client positions (indices into ``client_list``) assigned to it.
    """
    assignment = {cluster_idx: [] for cluster_idx in range(num_clusters)}
    for position, _client in enumerate(client_list):
        assignment[position % num_clusters].append(position)
    return assignment

In [44]:
import pickle

# Initial round-robin assignment of the selected clients to clusters.
config["num_clusters"] = 2
cluster_map = init_cluster_map(config["num_clusters"], config["selected_clients"])

# Registries resolved by name from the config in BaseTrainer / ClusterTrainer.
MODEL_LIST = {"cnn": SimpleCNN}
OPTIMIZER_LIST = {"sgd": optim.SGD, "adam": optim.Adam}
LOSSES = {"cross_entropy": nn.CrossEntropyLoss()}

config["iterations"] = 100
config["optimizer_params"] = {"lr": 0.001}
config["save_freq"] = 2
config["print_freq"] = 20
config["model"] = "cnn"
config["optimizer"] = "adam"
config["loss_func"] = "cross_entropy"
config["model_params"] = {}
# Fall back to CPU when no GPU is visible so the training cells still run.
config["device"] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [45]:
config["num_rounds"] = 1
# One trainer (with freshly loaded pretrained weights) per initial cluster, in cluster-id order.
cluster_trainers = [ClusterTrainer(config, "", cluster_id) for cluster_id in cluster_map.keys()]
    

In [46]:
def calc_loss(model, device, client_data, train, loss_func):
    """Per-sample average loss of ``model`` on one client's train or test loader.

    Batch losses are weighted by batch size, so the result is a true
    per-sample mean even when the last batch is smaller. For a
    ``reduction="mean"`` criterion (the default of ``nn.CrossEntropyLoss``,
    assumed when ``loss_func`` has no ``reduction`` attribute), each batch value
    is multiplied by the batch size. For other reductions the batch loss is summed.

    Puts ``model`` in eval mode and moves it to ``device``.

    Returns:
        float average loss, or ``float("inf")`` if the loader yields no samples.
    """
    loader = client_data.trainloader if train else client_data.testloader
    model.eval()
    model.to(device)
    scale_by_batch = getattr(loss_func, "reduction", "mean") == "mean"
    tot_loss = 0.0
    tot_num = 0
    with torch.no_grad():
        for (X, Y) in loader:
            out = model(X.to(device)).detach().cpu()
            batch_loss = loss_func(out, Y)
            batch_size = Y.shape[0]
            if scale_by_batch:
                # A batch mean must be turned back into a sum before the global division.
                tot_loss += batch_loss.item() * batch_size
            else:
                tot_loss += batch_loss.sum().item()
            tot_num += batch_size
    if tot_num == 0:
        return float("inf")
    return tot_loss / tot_num
def recluster(config, cluster_trainers, client_loaders):
    """Assign each client to the cluster whose model has the lowest training loss on it.

    Args:
        config: dict providing ``"device"``.
        cluster_trainers: list of trainers; list position is the cluster id.
        client_loaders: list of ``Client`` objects; list position is the client id.

    Returns:
        dict mapping every cluster id in ``range(len(cluster_trainers))`` to the
        list of client positions assigned to it (possibly empty). Ties keep the
        lowest cluster id.
    """
    # CrossEntropyLoss is stateless; build it once instead of per evaluation.
    loss_func = nn.CrossEntropyLoss()
    new_map = {cluster_idx: [] for cluster_idx in range(len(cluster_trainers))}
    for client_id, client in enumerate(client_loaders):
        # float("inf"): np.infty no longer exists in NumPy 2.0.
        best_loss = float("inf")
        best_cluster_idx = 0
        for cluster_id, trainer in enumerate(cluster_trainers):
            client_loss = calc_loss(trainer.model, config["device"], client, train=True, loss_func=loss_func)
            if best_loss > client_loss:
                best_loss = client_loss
                best_cluster_idx = cluster_id
        new_map[best_cluster_idx].append(client_id)
    return new_map

In [48]:
for round_idx in tqdm(range(config["num_rounds"])):
    # Reassign clients to the cluster model that fits them best, then train each cluster.
    cluster_map = recluster(config, cluster_trainers, client_loaders)

    round_dir = os.path.join(config["results_dir"], "round_{}".format(round_idx))
    # The round directory is otherwise only created later by save_model_weights,
    # so writing the map first would fail on a fresh run.
    os.makedirs(round_dir, exist_ok=True)
    with open(os.path.join(round_dir, "cluster_maps.pkl"), 'wb') as handle:
        pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    for cluster_id, member_ids in tqdm(cluster_map.items()):
        if not member_ids:
            # No clients chose this cluster this round; there is nothing to train on.
            continue
        cluster_clients = [client_loaders[i] for i in member_ids]
        cluster_trainers[cluster_id].save_dir = os.path.join(round_dir, "cluster_{}".format(cluster_id))
        cluster_trainers[cluster_id].train(cluster_clients)
    if round_idx == config["num_rounds"] - 1:
        with open(os.path.join(config["results_dir"], "final_cluster_map.pkl"), 'wb') as handle:
            pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/2 [00:00<?, ?it/s][A

  0%|                                                   | 0/100 [00:00<?, ?it/s][A[A

  1%|▍                                          | 1/100 [00:04<07:49,  4.74s/it][A[A

Iteration : 0 
 , Train Loss : 0.4045368963852525 
, Test Acc : 39.59782028198242 





  2%|▊                                          | 2/100 [00:06<05:14,  3.21s/it][A[A

  3%|█▎                                         | 3/100 [00:09<04:57,  3.07s/it][A[A

  4%|█▋                                         | 4/100 [00:12<04:44,  2.97s/it][A[A

  5%|██▏                                        | 5/100 [00:15<04:46,  3.01s/it][A[A

  6%|██▌                                        | 6/100 [00:18<04:31,  2.89s/it][A[A

  7%|███                                        | 7/100 [00:22<04:53,  3.15s/it][A[A

  8%|███▍                                       | 8/100 [00:24<04:20,  2.84s/it][A[A

  9%|███▊                                       | 9/100 [00:27<04:25,  2.91s/it][A[A

 10%|████▏                                     | 10/100 [00:30<04:19,  2.89s/it][A[A

 11%|████▌                                     | 11/100 [00:37<06:28,  4.37s/it][A[A

 12%|█████                                     | 12/100 [00:39<05:24,  3.69s/it][A[A

 13%|█████▍                   

Iteration : 20 
 , Train Loss : 0.4938193145208061 
, Test Acc : 72.2472152709961 





 22%|█████████▏                                | 22/100 [01:34<04:20,  3.34s/it][A[A

 23%|█████████▋                                | 23/100 [01:37<04:06,  3.21s/it][A[A

 24%|██████████                                | 24/100 [01:39<03:32,  2.80s/it][A[A

 25%|██████████▌                               | 25/100 [01:43<03:54,  3.12s/it][A[A

 26%|██████████▉                               | 26/100 [01:45<03:41,  3.00s/it][A[A

 27%|███████████▎                              | 27/100 [01:49<03:40,  3.03s/it][A[A

 28%|███████████▊                              | 28/100 [01:51<03:27,  2.88s/it][A[A

 29%|████████████▏                             | 29/100 [01:54<03:28,  2.94s/it][A[A

 30%|████████████▌                             | 30/100 [01:56<03:09,  2.70s/it][A[A

 31%|█████████████                             | 31/100 [02:01<03:37,  3.16s/it][A[A

 32%|█████████████▍                            | 32/100 [02:03<03:13,  2.85s/it][A[A

 33%|█████████████▊           

Iteration : 40 
 , Train Loss : 0.33816722920164466 
, Test Acc : 76.42537689208984 





 42%|█████████████████▋                        | 42/100 [02:32<02:41,  2.78s/it][A[A

 43%|██████████████████                        | 43/100 [02:35<02:54,  3.06s/it][A[A

 44%|██████████████████▍                       | 44/100 [02:37<02:35,  2.78s/it][A[A

 45%|██████████████████▉                       | 45/100 [02:41<02:38,  2.89s/it][A[A

 46%|███████████████████▎                      | 46/100 [02:44<02:42,  3.00s/it][A[A

 47%|███████████████████▋                      | 47/100 [02:47<02:36,  2.96s/it][A[A

 48%|████████████████████▏                     | 48/100 [02:49<02:20,  2.70s/it][A[A

 49%|████████████████████▌                     | 49/100 [02:53<02:34,  3.04s/it][A[A

 50%|█████████████████████                     | 50/100 [02:55<02:18,  2.76s/it][A[A

 51%|█████████████████████▍                    | 51/100 [02:58<02:24,  2.95s/it][A[A

 52%|█████████████████████▊                    | 52/100 [03:01<02:19,  2.90s/it][A[A

 53%|██████████████████████▎  

Iteration : 60 
 , Train Loss : 0.26741280127316713 
, Test Acc : 76.99602508544922 





 62%|██████████████████████████                | 62/100 [03:54<02:42,  4.27s/it][A[A

 63%|██████████████████████████▍               | 63/100 [03:57<02:23,  3.87s/it][A[A

 64%|██████████████████████████▉               | 64/100 [03:59<02:07,  3.55s/it][A[A

 65%|███████████████████████████▎              | 65/100 [04:02<02:00,  3.44s/it][A[A

 66%|███████████████████████████▋              | 66/100 [04:05<01:48,  3.20s/it][A[A

 67%|████████████████████████████▏             | 67/100 [04:09<01:47,  3.25s/it][A[A

 68%|████████████████████████████▌             | 68/100 [04:10<01:30,  2.82s/it][A[A

 69%|████████████████████████████▉             | 69/100 [04:13<01:26,  2.79s/it][A[A

 70%|█████████████████████████████▍            | 70/100 [04:15<01:19,  2.65s/it][A[A

 71%|█████████████████████████████▊            | 71/100 [04:18<01:19,  2.73s/it][A[A

 72%|██████████████████████████████▏           | 72/100 [04:20<01:08,  2.44s/it][A[A

 73%|█████████████████████████

Iteration : 80 
 , Train Loss : 0.2550305149052292 
, Test Acc : 77.4050521850586 





 82%|██████████████████████████████████▍       | 82/100 [04:48<00:51,  2.85s/it][A[A

 83%|██████████████████████████████████▊       | 83/100 [04:51<00:47,  2.79s/it][A[A

 84%|███████████████████████████████████▎      | 84/100 [04:52<00:39,  2.49s/it][A[A

 85%|███████████████████████████████████▋      | 85/100 [04:56<00:41,  2.75s/it][A[A

 86%|████████████████████████████████████      | 86/100 [04:58<00:36,  2.59s/it][A[A

 87%|████████████████████████████████████▌     | 87/100 [05:00<00:33,  2.55s/it][A[A

 88%|████████████████████████████████████▉     | 88/100 [05:03<00:29,  2.48s/it][A[A

 89%|█████████████████████████████████████▍    | 89/100 [05:05<00:27,  2.53s/it][A[A

 90%|█████████████████████████████████████▊    | 90/100 [05:07<00:22,  2.30s/it][A[A

 91%|██████████████████████████████████████▏   | 91/100 [05:13<00:29,  3.28s/it][A[A

 92%|██████████████████████████████████████▋   | 92/100 [05:15<00:22,  2.82s/it][A[A

 93%|█████████████████████████

Iteration : 99 
 , Train Loss : 0.21574696619063616 
, Test Acc : 78.14800262451172 





  0%|                                                   | 0/100 [00:00<?, ?it/s][A[A

  1%|▍                                          | 1/100 [00:03<06:21,  3.85s/it][A[A

Iteration : 0 
 , Train Loss : 0.4339981717057526 
, Test Acc : 38.865013122558594 





  2%|▊                                          | 2/100 [00:05<04:31,  2.77s/it][A[A

  3%|█▎                                         | 3/100 [00:08<04:01,  2.49s/it][A[A

  4%|█▋                                         | 4/100 [00:11<04:24,  2.75s/it][A[A

  5%|██▏                                        | 5/100 [00:13<04:01,  2.54s/it][A[A

  6%|██▌                                        | 6/100 [00:15<03:45,  2.40s/it][A[A

  7%|███                                        | 7/100 [00:18<04:17,  2.76s/it][A[A

  8%|███▍                                       | 8/100 [00:20<03:48,  2.48s/it][A[A

  9%|███▊                                       | 9/100 [00:22<03:34,  2.35s/it][A[A

 10%|████▏                                     | 10/100 [00:26<03:52,  2.59s/it][A[A

 11%|████▌                                     | 11/100 [00:28<03:54,  2.64s/it][A[A

 12%|█████                                     | 12/100 [00:30<03:37,  2.48s/it][A[A

 13%|█████▍                   

Iteration : 20 
 , Train Loss : 0.5303092313697562 
, Test Acc : 72.89054870605469 





 22%|█████████▏                                | 22/100 [00:59<03:56,  3.03s/it][A[A

 23%|█████████▋                                | 23/100 [01:02<03:42,  2.89s/it][A[A

 24%|██████████                                | 24/100 [01:04<03:25,  2.70s/it][A[A

 25%|██████████▌                               | 25/100 [01:08<03:56,  3.15s/it][A[A

 26%|██████████▉                               | 26/100 [01:11<03:40,  2.98s/it][A[A

 27%|███████████▎                              | 27/100 [01:14<03:29,  2.87s/it][A[A

 28%|███████████▊                              | 28/100 [01:17<03:43,  3.11s/it][A[A

 29%|████████████▏                             | 29/100 [01:20<03:29,  2.94s/it][A[A

 30%|████████████▌                             | 30/100 [01:22<03:12,  2.75s/it][A[A

 31%|█████████████                             | 31/100 [01:27<03:45,  3.26s/it][A[A

 32%|█████████████▍                            | 32/100 [01:29<03:21,  2.97s/it][A[A

 33%|█████████████▊           

Iteration : 40 
 , Train Loss : 0.36108798650093377 
, Test Acc : 76.27527618408203 





 42%|█████████████████▋                        | 42/100 [01:57<02:29,  2.57s/it][A[A

 43%|██████████████████                        | 43/100 [02:00<02:44,  2.89s/it][A[A

 44%|██████████████████▍                       | 44/100 [02:02<02:25,  2.59s/it][A[A

 45%|██████████████████▉                       | 45/100 [02:04<02:18,  2.52s/it][A[A

 46%|███████████████████▎                      | 46/100 [02:08<02:36,  2.89s/it][A[A

 47%|███████████████████▋                      | 47/100 [02:11<02:25,  2.75s/it][A[A

 48%|████████████████████▏                     | 48/100 [02:13<02:16,  2.62s/it][A[A

 49%|████████████████████▌                     | 49/100 [02:17<02:36,  3.06s/it][A[A

 50%|█████████████████████                     | 50/100 [02:19<02:18,  2.76s/it][A[A

 51%|█████████████████████▍                    | 51/100 [02:21<02:09,  2.64s/it][A[A

 52%|█████████████████████▊                    | 52/100 [02:24<02:12,  2.75s/it][A[A

 53%|██████████████████████▎  

Iteration : 60 
 , Train Loss : 0.33060413831844926 
, Test Acc : 75.5167236328125 





 62%|██████████████████████████                | 62/100 [02:50<01:39,  2.62s/it][A[A

 63%|██████████████████████████▍               | 63/100 [02:52<01:35,  2.57s/it][A[A

 64%|██████████████████████████▉               | 64/100 [02:56<01:44,  2.89s/it][A[A

 65%|███████████████████████████▎              | 65/100 [02:59<01:37,  2.79s/it][A[A

 66%|███████████████████████████▋              | 66/100 [03:01<01:32,  2.73s/it][A[A

 67%|████████████████████████████▏             | 67/100 [03:05<01:43,  3.15s/it][A[A

 68%|████████████████████████████▌             | 68/100 [03:08<01:32,  2.89s/it][A[A

 69%|████████████████████████████▉             | 69/100 [03:10<01:26,  2.78s/it][A[A

 70%|█████████████████████████████▍            | 70/100 [03:14<01:30,  3.03s/it][A[A

 71%|█████████████████████████████▊            | 71/100 [03:17<01:25,  2.96s/it][A[A

 72%|██████████████████████████████▏           | 72/100 [03:19<01:17,  2.76s/it][A[A

 73%|█████████████████████████

Iteration : 80 
 , Train Loss : 0.2818481170106679 
, Test Acc : 76.810791015625 





 82%|██████████████████████████████████▍       | 82/100 [03:49<00:54,  3.05s/it][A[A

 83%|██████████████████████████████████▊       | 83/100 [03:52<00:48,  2.88s/it][A[A

 84%|███████████████████████████████████▎      | 84/100 [03:54<00:42,  2.69s/it][A[A

 85%|███████████████████████████████████▋      | 85/100 [03:58<00:46,  3.10s/it][A[A

 86%|████████████████████████████████████      | 86/100 [04:01<00:41,  2.93s/it][A[A

 87%|████████████████████████████████████▌     | 87/100 [04:03<00:36,  2.82s/it][A[A

 88%|████████████████████████████████████▉     | 88/100 [04:07<00:36,  3.05s/it][A[A

 89%|█████████████████████████████████████▍    | 89/100 [04:09<00:31,  2.88s/it][A[A

 90%|█████████████████████████████████████▊    | 90/100 [04:11<00:26,  2.70s/it][A[A

 91%|██████████████████████████████████████▏   | 91/100 [04:15<00:27,  3.10s/it][A[A

 92%|██████████████████████████████████████▋   | 92/100 [04:18<00:22,  2.85s/it][A[A

 93%|█████████████████████████

Iteration : 99 
 , Train Loss : 0.21598578715929762 
, Test Acc : 77.58863067626953 






{0: [2, 8, 10, 12, 14, 16, 17, 18, 22, 24, 29, 30],
 1: [0, 1, 3, 4, 5, 6, 7, 9, 11, 13, 15, 19, 20, 21, 23, 25, 26, 27, 28, 31]}

In [None]:
# Refinement pass (this cell has not been executed: In [None]). Each pass:
#   1. trains a fresh ClusterTrainer per current cluster,
#   2. reassigns clients to the cluster whose weights are closest (model_weights_diff),
#   3. links clusters whose models differ by less than `thresh` and merges them via correlation_clustering,
#   4. saves the averaged weights of each merged cluster and their mean accuracy.
# NOTE(review): this cell uses names not defined in any visible cell:
#   config["refine_steps"], client_trainers, model_weights_diff, nx, itertools, thresh,
#   correlation_clustering. Presumably nx is networkx — TODO define/import these before running.
for new in tqdm(range(config["refine_steps"])):
    # NOTE(review): refine_step is reset to 1 on every pass, so every pass writes into
    # "refine_1"; presumably it should be derived from `new` — TODO confirm.
    refine_step = 1
    # NOTE(review): beta is never used below.
    beta = 0.2
    cluster_trainers = []
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        cluster_trainer = ClusterTrainer(config, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "cluster_{}".format(cluster_id)), cluster_id)
        cluster_trainer.train(cluster_clients)
        cluster_trainers.append(cluster_trainer)
    # NOTE(review): "refine_<step>" is not created here; this open() relies on a trainer
    # having already created it (save_model_weights makes its cluster subdirectory).
    with open(os.path.join(config["results_dir"],"refine_{}".format(refine_step), "cluster_maps.pkl"), 'wb') as handle:
        pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    cluster_map_recluster = {}
    for key in cluster_map.keys():
        cluster_map_recluster[key] = []

    # Reassign each client to the cluster whose weights are closest to that client's own model.
    # NOTE(review): iterates range(config["num_clients"]), but cluster_map holds positions in
    # client_loaders (len(config["selected_clients"])); these counts may differ — verify.
    for i in tqdm(range(config["num_clients"])):
        w_node = client_trainers[i].model.state_dict()
        # NOTE(review): np.infty was removed in NumPy 2.0; np.inf is the portable spelling.
        norm_diff = np.infty
        new_cluster_id = 0
        for cluster_id in cluster_map.keys():
            w_cluster = cluster_trainers[cluster_id].model.state_dict()
            curr_norm_diff = model_weights_diff(w_node, w_cluster)
            if norm_diff > curr_norm_diff:
                new_cluster_id = cluster_id
                norm_diff = curr_norm_diff
        
        cluster_map_recluster[new_cluster_id].append(i)
    # Drop clusters that received no clients.
    keys = list(cluster_map_recluster.keys()).copy()
    for key in keys:
        if len(cluster_map_recluster[key]) == 0:
            cluster_map_recluster.pop(key)
    cluster_map = cluster_map_recluster

    
    # Graph over the surviving clusters: an edge joins two clusters whose model weights
    # differ (per model_weights_diff) by less than `thresh`.
    G = nx.Graph()
    G.add_nodes_from(cluster_map.keys())

    all_pairs = list(itertools.combinations(cluster_map.keys(),2))
    for pair in tqdm(all_pairs):
        w_1  = cluster_trainers[pair[0]].model.state_dict()
        w_2 = cluster_trainers[pair[1]].model.state_dict()
        norm_diff = model_weights_diff(w_1, w_2)
        if norm_diff < thresh:
            G.add_edge(pair[0], pair[1])
    G = G.to_undirected()
    clustering = []        
    # NOTE(review): the return value of correlation_clustering(G) is discarded, so `clustering`
    # stays [] and merge_clusters / cluster_map end up empty; the final division by
    # len(cluster_map.keys()) then raises ZeroDivisionError. Presumably
    # `clustering = correlation_clustering(G)` was intended — TODO confirm its return type.
    correlation_clustering(G)
    merge_clusters = [cluster  for cluster in clustering if len(cluster) > 0]
    
    #merge_cluster_map = {i: clusters[i] for i in range(len(clusters))}
    #clusters = list(nx.algorithms.clique.enumerate_all_cliques(G))
    # Each merged cluster's client list is the concatenation of its member clusters' clients.
    cluster_map_new = {}
    for i in range(len(merge_clusters)):
        cluster_map_new[i] = []
        for j in merge_clusters[i]:
            cluster_map_new[i] += cluster_map[j]
    cluster_map = cluster_map_new
    # Average the member models of each merged cluster, evaluate on its clients, and save.
    test_acc = 0
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        model_wts = [cluster_trainers[j].model.state_dict() for j in merge_clusters[cluster_id]]
        test_acc_cluster, model_avg_wt =avg_acc(model_wts,cluster_clients)
        torch.save(model_avg_wt, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "merged_cluster_{}.pth".format(cluster_id)))
        test_acc += test_acc_cluster
    test_acc = test_acc/len(cluster_map.keys())
    torch.save(test_acc, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "avg_acc.pth"))


In [None]:
# NOTE(review): GlobalTrainer is not defined in any visible cell — TODO define or import it before running.
global_save_dir = os.path.join(config["results_dir"], "global")
global_trainer = GlobalTrainer(config, global_save_dir)
global_trainer.train(client_loaders)
