In [1]:
%load_ext autoreload
%autoreload 2

In [2]:




import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from abc import ABC
from tqdm import tqdm
import torchvision
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import json
from collections import OrderedDict
import shutil

In [5]:
shutil.rmtree("./experiments/results/femnist/seed_9998")

In [3]:
torch.cuda.empty_cache()

301

In [None]:
config = {}


In [76]:
config["seed"] = 8
seed = config["seed"]
os.environ['PYTHONHASHSEED'] = str(seed)
# Torch RNG
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Python RNG
np.random.seed(seed)
random.seed(seed)


In [71]:
config["participation_ratio"] = 0.05/6
#config["total_num_clients_per_cluster"] = 80
#config["num_clients_per_cluster"] = int(config["participation_ratio"]*config["total_num_clients_per_cluster"])
#config["num_clusters"] = 4
#config["num_clients"] = config["num_clients_per_cluster"]*config["num_clusters"]
config["dataset"] = "femnist"
#DATASET_LIB = {"mnist" : torchvision.datasets.MNIST, "emnist": torchvision.datasets.EMNIST, "cifar10": torchvision.datasets.CIFAR10}
config["dataset_dir"] = "/base_vol/femnist/data"
config["results_dir"] = "./experiments/results"
config["results_dir"] = os.path.join(config["results_dir"], config["dataset"], "seed_{}".format(seed))
os.makedirs(config["results_dir"], exist_ok=True)


In [72]:
from collections import defaultdict
def read_dir(data_dir):
    clients = []
    groups = []
    data = defaultdict(lambda : None)

    files = os.listdir(data_dir)
    files = [f for f in files if f.endswith('.json')]
    for f in files:
        file_path = os.path.join(data_dir,f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    clients = list(sorted(data.keys()))
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    '''parses data in given train and test data directories
    assumes:
    - the data in the input directories are .json files with 
        keys 'users' and 'user_data'
    - the set of train set users is the same as the set of test set users
    
    Return:
        clients: list of client ids
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    '''
    train_clients, train_groups, train_data = read_dir(train_data_dir)
    test_clients, test_groups, test_data = read_dir(test_data_dir)

    assert train_clients == test_clients
    assert train_groups == test_groups

    return train_clients, train_groups, train_data, test_data


In [38]:
config["total_clients"], _, train_data, test_data = read_data(os.path.join(config["dataset_dir"],"train"), os.path.join(config["dataset_dir"],"test"))


In [7]:
class ClientDataset(Dataset):
    def __init__(self, data,transforms = None):
        super(ClientDataset,self).__init__()
        self.data = data[0]
        self.labels = data[1]
        self.transforms = transforms

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        idx_data = self.data[idx].reshape(28,28)
        if self.transforms is not None:
            transformed_data =self.transforms(idx_data)
        else:
            transformed_data = idx_data
        idx_labels = self.labels[idx]
        return (transformed_data.float(), idx_labels)

class Client():
    def __init__(self, train_data, test_data, client_id,  train_transforms, test_transforms, train_batch_size, test_batch_size, save_dir):
        self.trainset = ClientDataset(train_data, train_transforms)
        self.testset = ClientDataset(test_data, test_transforms)
        self.trainloader = DataLoader(self.trainset, batch_size = train_batch_size, shuffle=True)
        self.testloader = DataLoader(self.testset, batch_size = test_batch_size, shuffle=False)
        self.train_iterator = iter(self.trainloader)
        self.test_iterator = iter(self.testloader)
        self.client_id = client_id
        self.save_dir = os.path.join(save_dir, "init", "client_{}".format(client_id))

    def sample_batch(self, train=True):
        iterator = self.train_iterator if train else self.test_iterator
        try:
            (data, labels) = next(iterator)
        except StopIteration:
            if train:
                self.train_iterator = iter(self.trainloader)
                iterator = self.train_iterator
            else:
                self.test_iterator = iter(self.testloader)
                iterator = self.test_iterator
            (data, labels) = next(iterator)
        return (data, labels)


In [39]:
config["train_batch"] = 64
config["test_batch"] = 512
config["num_clients"] = int(len(config["total_clients"])/6)


In [77]:
## Generate new clients
selected_clients_path = os.path.join(config["results_dir"], "selected_clients.json")
if os.path.exists(selected_clients_path):
    with open(selected_clients_path, "r") as f:
        config["selected_clients"] = json.load(f)
else:
    config["selected_clients"] = random.sample(config["total_clients"], config["num_clients"])
    with open(selected_clients_path, "w") as f:
        json.dump(config["selected_clients"], f)

In [78]:
client_loaders = []
for client_id in config["selected_clients"]:
        client_loaders.append(
            Client(
                (np.array(train_data[client_id]['x']), np.array(train_data[client_id]['y'])),
                (np.array(test_data[client_id]['x']), np.array(test_data[client_id]['y'])),
                client_id,
                train_transforms=torchvision.transforms.ToTensor(),
                test_transforms=torchvision.transforms.ToTensor(),
                train_batch_size=config["train_batch"],
                test_batch_size=config["test_batch"],
                save_dir=config["results_dir"],
            )
        )


In [42]:
class SimpleCNN(torch.nn.Module):

    def __init__(self, h1=2048):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size = (5,5), padding="same")
        self.pool1 = torch.nn.MaxPool2d((2,2), stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size= (5,5), padding = "same")
        self.pool2 = torch.nn.MaxPool2d((2,2), stride=2)
        self.fc1 = torch.nn.Linear(64*7*7, 2048)
        self.fc2 = torch.nn.Linear(2048, 62)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [12]:
def set_weights(model):
    model_wt = torch.load('/base_vol/model_wt_dict.pt')
    new_wts = OrderedDict()
    new_wts['fc2.weight'] = torch.Tensor(model_wt["dense_1/kernel"]).t()
    new_wts['fc2.bias'] = torch.Tensor(model_wt["dense_1/bias"])
    new_wts['fc1.weight'] = torch.Tensor(model_wt["dense/kernel"]).t()
    new_wts['fc1.bias'] = torch.Tensor(model_wt["dense/bias"])
    new_wts["conv1.weight"] = torch.Tensor(model_wt["conv2d/kernel"]).permute(3,2,0,1)
    new_wts["conv2.weight"] = torch.Tensor(model_wt["conv2d_1/kernel"]).permute(3,2,0,1)
    new_wts["conv1.bias"] = torch.Tensor(model_wt["conv2d/bias"])
    new_wts["conv2.bias"] = torch.Tensor(model_wt["conv2d_1/bias"])
    model.load_state_dict(new_wts)
    return freeze_layers(model)
def freeze_layers(model):
    model.conv1.weight.requires_grad =False
    model.conv2.weight.requires_grad =False
    model.fc1.weight.requires_grad =True
    model.fc2.weight.requires_grad =True
    model.conv1.bias.requires_grad =False
    model.conv2.bias.requires_grad =False
    model.fc1.bias.requires_grad =True
    model.fc2.bias.requires_grad =True
    return model

In [13]:
def calc_acc(model, device, client_data, train):
    loader = client_data.trainloader if train else client_data.testloader
    model.eval()
    model.to(device)
    num_corr = 0
    tot_num = 0
    with torch.no_grad():
        for (X,Y) in loader:
            X = X.to(device)
            pred = model(X).argmax(axis=1).detach().cpu()
            num_corr += (Y == pred).float().sum()
            tot_num += Y.shape[0]
    acc = num_corr/tot_num
    acc *= 100.0
    return acc


In [14]:
class BaseTrainer(ABC):
    def __init__(self,config, save_dir):
        super(BaseTrainer, self).__init__()
        self.model = set_weights(MODEL_LIST[config["model"]](**config["model_params"]))
        self.save_dir = save_dir
        self.device = config["device"]
        self.loss_func = LOSSES[config["loss_func"]]
        self.config = config
        os.makedirs(self.save_dir, exist_ok=True)

    def train(self):
        raise NotImplementedError
    
    def test(self):
        raise NotImplementedError

    def load_model_weights(self):
        model_path  = os.path.join(self.save_dir, "model.pth")
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path))
        else:
            print("No model present at path : {}".format())

    def save_model_weights(self):
        model_path  = os.path.join(self.save_dir, "model.pth")
        torch.save(self.model.state_dict(), model_path)
    def save_metrics(self, train_loss, test_acc, iteration):
        torch.save({"train_loss": train_loss,  "test_acc" : test_acc}, os.path.join(self.save_dir,"metrics_{}.pkl".format(iteration)))

class ClientTrainer(BaseTrainer):
    def __init__(self,  config, save_dir,client_id):
        super(ClientTrainer, self).__init__(config, save_dir)
        self.client_id = client_id
    
    def train(self, client_data):
        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        for iteration in range(self.config["iterations"]):
            self.model.zero_grad()
            (X,Y) = client_data.sample_batch(train=True)
            X = X.to(self.device)
            Y = Y.to(self.device)
            out = self.model(X)
            loss = self.loss_func(out, Y)
            loss.backward()
            optimizer.step()
            train_loss = loss.detach().cpu().numpy().item()
            train_loss_list.append(train_loss)
            test_acc = calc_acc(self.model, self.device, client_data, train=False)
            test_acc_list.append(test_acc)
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == self.config["iterations"] - 1:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))
                
        self.model.eval()
        self.model.cpu()


    def test(self, client_data):
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        acc =  calc_acc(self.model, client_data)
        self.model.cpu()
        return acc




In [43]:
  
MODEL_LIST = {"cnn" : SimpleCNN}
OPTIMIZER_LIST = {"sgd": optim.SGD, "adam": optim.Adam}
LOSSES = {"cross_entropy": nn.CrossEntropyLoss()}
# config["save_dir"] = os.path.join("./results")
config["iterations"] = 100
config["optimizer_params"] = {"lr":0.001}
config["save_freq"] = 2
config["print_freq"]  = 50
config["model"] = "cnn"
config["optimizer"] = "adam"
config["loss_func"] = "cross_entropy"
#config["model_params"] = {"num_channels": 1 , "num_classes"  : 62}
config["model_params"] = {}
config["device"] = torch.device("cuda:0")
import pickle






In [46]:
client_trainers = [ClientTrainer(config,os.path.join(config["results_dir"], "init", client_id), client_id) for client_id in config["selected_clients"]]


In [47]:
for i in tqdm(range(len(config["selected_clients"]))):
    client_trainers[i].train(client_loaders[i])


  0%|                                                                         | 0/32 [00:00<?, ?it/s]

Iteration : 0 
 , Train Loss : 6.637035846710205 
, Test Acc : 3.7037036418914795 

Iteration : 50 
 , Train Loss : 0.29978856444358826 
, Test Acc : 85.18518829345703 



  3%|██                                                               | 1/32 [00:11<05:53, 11.40s/it]

Iteration : 99 
 , Train Loss : 0.03732234239578247 
, Test Acc : 85.18518829345703 

Iteration : 0 
 , Train Loss : 6.662721157073975 
, Test Acc : 5.263157844543457 

Iteration : 50 
 , Train Loss : 0.2584386467933655 
, Test Acc : 81.57894897460938 



  6%|████                                                             | 2/32 [00:21<05:26, 10.90s/it]

Iteration : 99 
 , Train Loss : 0.09703311324119568 
, Test Acc : 92.10526275634766 

Iteration : 0 
 , Train Loss : 8.357836723327637 
, Test Acc : 27.77777862548828 

Iteration : 50 
 , Train Loss : 0.012913631275296211 
, Test Acc : 66.66667175292969 



  9%|██████                                                           | 3/32 [00:34<05:35, 11.57s/it]

Iteration : 99 
 , Train Loss : 0.002325940178707242 
, Test Acc : 66.66667175292969 

Iteration : 0 
 , Train Loss : 7.451553821563721 
, Test Acc : 11.111111640930176 

Iteration : 50 
 , Train Loss : 0.0485992431640625 
, Test Acc : 55.55555725097656 



 12%|████████▏                                                        | 4/32 [01:00<08:07, 17.41s/it]

Iteration : 99 
 , Train Loss : 0.02243007719516754 
, Test Acc : 55.55555725097656 

Iteration : 0 
 , Train Loss : 8.230623245239258 
, Test Acc : 8.333333969116211 

Iteration : 50 
 , Train Loss : 0.027989299967885017 
, Test Acc : 33.333335876464844 



 16%|██████████▏                                                      | 5/32 [01:11<06:48, 15.15s/it]

Iteration : 99 
 , Train Loss : 0.011420486494898796 
, Test Acc : 33.333335876464844 

Iteration : 0 
 , Train Loss : 8.040980339050293 
, Test Acc : 9.090909004211426 

Iteration : 50 
 , Train Loss : 0.03169284760951996 
, Test Acc : 36.3636360168457 



 19%|████████████▏                                                    | 6/32 [01:22<05:52, 13.57s/it]

Iteration : 99 
 , Train Loss : 0.03318779170513153 
, Test Acc : 36.3636360168457 

Iteration : 0 
 , Train Loss : 6.923524856567383 
, Test Acc : 35.29411697387695 

Iteration : 50 
 , Train Loss : 0.0794343501329422 
, Test Acc : 70.5882339477539 



 22%|██████████████▏                                                  | 7/32 [01:32<05:13, 12.54s/it]

Iteration : 99 
 , Train Loss : 0.013054211623966694 
, Test Acc : 70.5882339477539 

Iteration : 0 
 , Train Loss : 5.4641242027282715 
, Test Acc : 7.692307949066162 

Iteration : 50 
 , Train Loss : 0.4184598922729492 
, Test Acc : 71.79487609863281 



 25%|████████████████▎                                                | 8/32 [01:40<04:22, 10.95s/it]

Iteration : 99 
 , Train Loss : 0.16541633009910583 
, Test Acc : 76.92308044433594 

Iteration : 0 
 , Train Loss : 6.554172992706299 
, Test Acc : 11.764705657958984 

Iteration : 50 
 , Train Loss : 0.07564735412597656 
, Test Acc : 64.70588684082031 



 28%|██████████████████▎                                              | 9/32 [01:59<05:12, 13.60s/it]

Iteration : 99 
 , Train Loss : 0.015431844629347324 
, Test Acc : 70.5882339477539 

Iteration : 0 
 , Train Loss : 6.708613872528076 
, Test Acc : 6.45161247253418 

Iteration : 50 
 , Train Loss : 0.18339774012565613 
, Test Acc : 67.74193572998047 



 31%|████████████████████                                            | 10/32 [02:23<06:09, 16.79s/it]

Iteration : 99 
 , Train Loss : 0.05663880333304405 
, Test Acc : 61.29032516479492 

Iteration : 0 
 , Train Loss : 8.799028396606445 
, Test Acc : 18.75 

Iteration : 50 
 , Train Loss : 0.2048787772655487 
, Test Acc : 62.5 



 34%|██████████████████████                                          | 11/32 [02:32<05:01, 14.37s/it]

Iteration : 99 
 , Train Loss : 0.04977525770664215 
, Test Acc : 62.5 

Iteration : 0 
 , Train Loss : 5.952098846435547 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.0030323301907628775 
, Test Acc : 0.0 



 38%|████████████████████████                                        | 12/32 [03:02<06:20, 19.03s/it]

Iteration : 99 
 , Train Loss : 0.0009455641265958548 
, Test Acc : 0.0 

Iteration : 0 
 , Train Loss : 6.376664638519287 
, Test Acc : 12.5 

Iteration : 50 
 , Train Loss : 0.17541885375976562 
, Test Acc : 50.0 



 41%|██████████████████████████                                      | 13/32 [03:09<04:51, 15.36s/it]

Iteration : 99 
 , Train Loss : 0.08128704130649567 
, Test Acc : 62.5 

Iteration : 0 
 , Train Loss : 6.298201084136963 
, Test Acc : 3.0303030014038086 

Iteration : 50 
 , Train Loss : 0.18745236098766327 
, Test Acc : 69.69696807861328 



 44%|████████████████████████████                                    | 14/32 [03:19<04:07, 13.77s/it]

Iteration : 99 
 , Train Loss : 0.11385001987218857 
, Test Acc : 75.75757598876953 

Iteration : 0 
 , Train Loss : 6.094166278839111 
, Test Acc : 9.375 

Iteration : 50 
 , Train Loss : 0.29447534680366516 
, Test Acc : 59.375 



 47%|██████████████████████████████                                  | 15/32 [03:50<05:25, 19.13s/it]

Iteration : 99 
 , Train Loss : 0.09123311936855316 
, Test Acc : 68.75 

Iteration : 0 
 , Train Loss : 6.213800430297852 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.31023669242858887 
, Test Acc : 85.36585235595703 



 50%|████████████████████████████████                                | 16/32 [04:00<04:22, 16.39s/it]

Iteration : 99 
 , Train Loss : 0.13614115118980408 
, Test Acc : 85.36585235595703 

Iteration : 0 
 , Train Loss : 6.590491771697998 
, Test Acc : 6.25 

Iteration : 50 
 , Train Loss : 0.29671910405158997 
, Test Acc : 78.125 



 53%|██████████████████████████████████                              | 17/32 [04:39<05:48, 23.23s/it]

Iteration : 99 
 , Train Loss : 0.25083592534065247 
, Test Acc : 84.375 

Iteration : 0 
 , Train Loss : 8.519227027893066 
, Test Acc : 12.5 

Iteration : 50 
 , Train Loss : 0.012801522389054298 
, Test Acc : 62.5 



 56%|████████████████████████████████████                            | 18/32 [04:59<05:08, 22.03s/it]

Iteration : 99 
 , Train Loss : 0.0015451579820364714 
, Test Acc : 68.75 

Iteration : 0 
 , Train Loss : 7.674419403076172 
, Test Acc : 3.7037036418914795 

Iteration : 50 
 , Train Loss : 0.09918440133333206 
, Test Acc : 55.55555725097656 



 59%|██████████████████████████████████████                          | 19/32 [05:25<05:05, 23.47s/it]

Iteration : 99 
 , Train Loss : 0.08216364681720734 
, Test Acc : 55.55555725097656 

Iteration : 0 
 , Train Loss : 5.907440185546875 
, Test Acc : 8.108107566833496 

Iteration : 50 
 , Train Loss : 0.2748275399208069 
, Test Acc : 78.37837982177734 



 62%|████████████████████████████████████████                        | 20/32 [06:34<07:22, 36.91s/it]

Iteration : 99 
 , Train Loss : 0.08560672402381897 
, Test Acc : 67.56756591796875 

Iteration : 0 
 , Train Loss : 7.155414581298828 
, Test Acc : 46.66666793823242 

Iteration : 50 
 , Train Loss : 2.5978810787200928 
, Test Acc : 86.66666412353516 



 66%|██████████████████████████████████████████                      | 21/32 [07:00<06:11, 33.80s/it]

Iteration : 99 
 , Train Loss : 0.02516932785511017 
, Test Acc : 93.33333587646484 

Iteration : 0 
 , Train Loss : 8.825970649719238 
, Test Acc : 16.666667938232422 

Iteration : 50 
 , Train Loss : 0.013488389551639557 
, Test Acc : 88.8888931274414 



 69%|████████████████████████████████████████████                    | 22/32 [07:15<04:40, 28.07s/it]

Iteration : 99 
 , Train Loss : 0.0031321528367698193 
, Test Acc : 88.8888931274414 

Iteration : 0 
 , Train Loss : 6.891552448272705 
, Test Acc : 6.6666669845581055 

Iteration : 50 
 , Train Loss : 0.05447373911738396 
, Test Acc : 80.0 



 72%|██████████████████████████████████████████████                  | 23/32 [07:35<03:51, 25.71s/it]

Iteration : 99 
 , Train Loss : 0.006795903202146292 
, Test Acc : 80.0 

Iteration : 0 
 , Train Loss : 6.029994964599609 
, Test Acc : 7.317072868347168 

Iteration : 50 
 , Train Loss : 0.09322480857372284 
, Test Acc : 78.04877471923828 



 75%|████████████████████████████████████████████████                | 24/32 [07:42<02:41, 20.17s/it]

Iteration : 99 
 , Train Loss : 0.04787515476346016 
, Test Acc : 78.04877471923828 

Iteration : 0 
 , Train Loss : 5.696196556091309 
, Test Acc : 14.285715103149414 

Iteration : 50 
 , Train Loss : 0.1754637509584427 
, Test Acc : 82.85714721679688 



 78%|██████████████████████████████████████████████████              | 25/32 [07:53<02:00, 17.16s/it]

Iteration : 99 
 , Train Loss : 0.09427152574062347 
, Test Acc : 80.0 

Iteration : 0 
 , Train Loss : 6.438925266265869 
, Test Acc : 9.756096839904785 

Iteration : 50 
 , Train Loss : 0.11693454533815384 
, Test Acc : 80.48780822753906 



 81%|████████████████████████████████████████████████████            | 26/32 [08:25<02:09, 21.66s/it]

Iteration : 99 
 , Train Loss : 0.031470321118831635 
, Test Acc : 80.48780822753906 

Iteration : 0 
 , Train Loss : 9.238253593444824 
, Test Acc : 5.882352828979492 

Iteration : 50 
 , Train Loss : 0.08145651966333389 
, Test Acc : 58.82353210449219 



 84%|██████████████████████████████████████████████████████          | 27/32 [08:33<01:27, 17.51s/it]

Iteration : 99 
 , Train Loss : 0.05059973523020744 
, Test Acc : 70.5882339477539 

Iteration : 0 
 , Train Loss : 8.057792663574219 
, Test Acc : 6.25 

Iteration : 50 
 , Train Loss : 0.018014831468462944 
, Test Acc : 87.5 



 88%|████████████████████████████████████████████████████████        | 28/32 [08:48<01:07, 16.89s/it]

Iteration : 99 
 , Train Loss : 0.028324808925390244 
, Test Acc : 87.5 

Iteration : 0 
 , Train Loss : 5.508042335510254 
, Test Acc : 13.88888931274414 

Iteration : 50 
 , Train Loss : 0.13821065425872803 
, Test Acc : 75.0 



 91%|██████████████████████████████████████████████████████████      | 29/32 [09:05<00:51, 17.00s/it]

Iteration : 99 
 , Train Loss : 0.1260848492383957 
, Test Acc : 75.0 

Iteration : 0 
 , Train Loss : 7.068293571472168 
, Test Acc : 6.25 



 91%|██████████████████████████████████████████████████████████      | 29/32 [09:09<00:56, 18.94s/it]


KeyboardInterrupt: 

In [21]:
## Load model weights
for trainer in client_trainers:
    trainer.load_model_weights()


In [22]:
local_acc = 0
for i in range(len(client_loaders)):
    local_acc += calc_acc(client_trainers[i].model, config["device"], client_loaders[i], train=False)
local_acc = local_acc/len(client_loaders)
print(f"Local Accuracy : {local_acc}")

Local Accuracy : 67.0177001953125


In [2]:
prin

NameError: name 'client_trainers' is not defined

In [18]:
a = 0
for i, trainer in enumerate(client_trainers):
    a+= calc_acc(trainer.model, config["device"], client_loaders[i], train=False)

In [19]:
print(a/len(client_trainers))

tensor(66.1461)


In [59]:
# config["iterations"] = 100
# config["optimizer_params"] = {"lr":0.001}
# config["save_freq"] = 2
# config["print_freq"]  = 40
# config["model"] = "cnn"
# config["optimizer"] = "adam"
# config["loss_func"] = "cross_entropy"
# #config["model_params"] = {"num_channels": 1 , "num_classes"  : 62}
# config["model_params"] = {}
# config["device"] = torch.device("cuda:0")

# client_id = config["selected_clients"][0]
# client_trainers[0] = ClientTrainer(config,os.path.join(config["results_dir"], "init", client_id), client_id)
# client_trainers[0].train(client_loaders[0])
import networkx as nx



In [None]:
import networkx as nx
G = nx.Graph()
G.add_nodes_from(range(config["num_clients"]))
import itertools
def model_weights_diff(w_1, w_2):
    norm_sq = 0
    assert w_1.keys() == w_2.keys(), "Model weights have different keys"
    for key in w_1.keys():
        norm_sq  += (w_1[key].cpu() - w_2[key].cpu()).norm()**2
    return np.sqrt(norm_sq)
wt = client_trainers[0].model.state_dict()


In [50]:
all_pairs = list(itertools.combinations(range(config["num_clients"]),2))
arr = []
for pair in all_pairs:
    w_1  = client_trainers[pair[0]].model.state_dict()
    w_2 = client_trainers[pair[1]].model.state_dict()
    norm_diff = model_weights_diff(w_1, w_2)
    arr.append(norm_diff)
#thresh = torch.mean(torch.tensor(arr))


In [51]:
thresh = arr[torch.tensor(arr).argsort()[int(0.3*len(arr))-1]]


In [52]:
clustering = []
def correlation_clustering(G):
    global clustering
    if len(G.nodes) == 0:
        return
    else:
        cluster = []
        new_cluster_pivot = random.sample(G.nodes,1)[0]
        cluster.append(new_cluster_pivot)
        neighbors = G[new_cluster_pivot].copy()
        for node in neighbors:
            cluster.append(node)
            G.remove_node(node)
        G.remove_node(new_cluster_pivot)
        clustering.append(cluster)
        correlation_clustering(G)


In [53]:
while True:
    G = nx.Graph()
    G.add_nodes_from(range(config["num_clients"]))
    for i in range(len(all_pairs)):
        if arr[i] < thresh:
            G.add_edge(all_pairs[i][0], all_pairs[i][1])
    G = G.to_undirected()
    clustering = []
    correlation_clustering(G)
    clusters = [cluster  for cluster in clustering if len(cluster) > 1 ]
    print(len(clusters))
    if len(clusters) >= 2:
        break
    
cluster_map = {i: clusters[i] for i in range(len(clusters))}
beta = 0.2

5


since Python 3.9 and will be removed in a subsequent version.
  new_cluster_pivot = random.sample(G.nodes,1)[0]


In [54]:
os.makedirs(os.path.join(config["results_dir"],"refine_0"),exist_ok=True)
with open(os.path.join(config["results_dir"],"refine_0", "cluster_maps.pkl"), 'wb') as handle:
    pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [56]:
cluster_map

{0: [8, 2, 3, 21, 22, 27, 29, 30, 31],
 1: [0, 15, 24, 25],
 2: [6, 10, 17],
 3: [26, 5],
 4: [16, 13]}

In [31]:
class ClusterTrainer(BaseTrainer):
    def __init__(self,  config, save_dir,cluster_id):
        super(ClusterTrainer, self).__init__(config, save_dir)
        self.cluster_id = cluster_id
    
    def train(self, client_data_list):
        num_clients = len(client_data_list)

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        
        
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        #eff_num_workers = int(num_clients/(1 - 2*beta))
        # if eff_num_workers > 0:
        #     eff_batch_size = self.config["train_batch"]/eff_num_workers
        #     for i in range(num_clients):
        #         client_data_list[i].trainloader.batch_size = eff_batch_size
                
        for iteration in tqdm(range(self.config["iterations"])):
            trmean_buffer = {}
            for idx, param in self.model.named_parameters():
                if param.requires_grad:
                    trmean_buffer[idx] = []
            train_loss = 0
            #optimizer.zero_grad(set_to_none=True)

            for client in client_data_list:
                #if eff_num_workers>0:
                optimizer.zero_grad(set_to_none=True)
                (X,Y) = client.sample_batch()
                X = X.to(config["device"])
                Y = Y.to(config["device"])
                loss_func = nn.CrossEntropyLoss()
                out = self.model(X)
                loss = loss_func(out,Y)
                loss.backward()
                train_loss += loss.detach().cpu().numpy().item()
                
                with torch.no_grad():
                    for idx, param in self.model.named_parameters():
                        if param.requires_grad:
                            trmean_buffer[idx].append(param.grad.clone())
            train_loss = train_loss/num_clients
            optimizer.zero_grad()
            
            start_idx = int(beta*num_clients)
            end_idx = int((1-beta)*num_clients)
            if end_idx <= start_idx + 1:
                start_idx = 0
                end_idx = num_clients


            for idx, param in self.model.named_parameters():
                if param.requires_grad:
                    sorted, _  = torch.sort(torch.stack(trmean_buffer[idx], dim=0), dim=0)
                    new_grad = sorted[start_idx:end_idx,...].mean(dim=0)
                    param.grad = new_grad
                    trmean_buffer[idx] = []
            optimizer.step()
            
            train_loss_list.append(train_loss)
            test_acc = 0
            for client_data in client_data_list:
                test_acc += calc_acc(self.model, self.device, client_data, train=False)
            test_acc = test_acc/num_clients
            test_acc_list.append(test_acc)
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == self.config["iterations"] - 1:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))
                
        self.model.eval()
        self.model.cpu()


    def test(self, client_data_list):
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        test_acc = test_acc/len(client_data_list)
        self.model.cpu()
        return test_acc


def avg_acc(model_wts, client_data_list):
    orig = model_wts[0]
    if len(model_wts) > 0:
        for wt in model_wts[1:]:
            for key in orig.keys():
                if orig[key].dtype == torch.float32:
                    orig[key] += wt[key] 
        for key in orig.keys():
            if orig[key].dtype == torch.float32:
                orig[key] = orig[key]/len(model_wts)
    model = SimpleCNN()
    model.load_state_dict(orig)
    model.to(memory_format = torch.channels_last).cuda()
    test_acc = 0
    for client_data in client_data_list:
        test_acc += calc_acc(model, torch.device("cuda:0"), client_data, train=False)
    test_acc = test_acc/len(client_data_list)
    return test_acc, orig


In [55]:
config["refine_steps"] = 2

In [None]:
import re

refine_step = 0
refine_path = os.path.join(config["results_dir"], f"refine_{refine_step}")
num_clusters = [x for x in os.listdir(refine_path) if re.match("cluster_\d", x) is not None]

with open(os.path.join(config["results_dir"], "refine_1", "cluster_maps.pkl"), "rb") as handle:
    cluster_map = pickle.load(handle)

In [None]:
import re
if re.match("cluster_\d", "cluster_1") is not None:
    print("here")
os.listdir(os.path.join(refine_path, "cluster_0", "model.pth"))

In [60]:
with open(os.path.join(config["results_dir"],"refine_1", "cluster_maps.pkl"), 'rb') as handle:
    cluster_map = pickle.load(handle)


In [64]:
(len(cluster_map[0]) *80.82542419433594 + len(cluster_map[1])*68.18305969238281)/(len(cluster_map[0]) + len(cluster_map[1]))

72.13379859924316

In [57]:
config["refine_steps"] = 2
for refine_step in tqdm(range(config["refine_steps"])):
    if os.path.exists(os.path.join(config["results_dir"],"refine_{}".format(refine_step), "cluster_maps.pkl")):
        shutil.rmtree(os.path.join(config["results_dir"],"refine_{}".format(refine_step)))
    beta = 0.15
    cluster_trainers = []
    
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        cluster_trainer = ClusterTrainer(config, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "cluster_{}".format(cluster_id)), cluster_id)
        cluster_trainer.train(cluster_clients)
        cluster_trainers.append(cluster_trainer)
    with open(os.path.join(config["results_dir"],"refine_{}".format(refine_step), "cluster_maps.pkl"), 'wb') as handle:
        pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    cluster_map_recluster = {}
    for key in cluster_map.keys():
        cluster_map_recluster[key] = []

    for i in tqdm(range(config["num_clients"])):
        w_node = client_trainers[i].model.state_dict()
        norm_diff = np.infty
        new_cluster_id = 0
        for cluster_id in cluster_map.keys():
            w_cluster = cluster_trainers[cluster_id].model.state_dict()
            curr_norm_diff = model_weights_diff(w_node, w_cluster)
            if norm_diff > curr_norm_diff:
                new_cluster_id = cluster_id
                norm_diff = curr_norm_diff
        
        cluster_map_recluster[new_cluster_id].append(i)
    keys = list(cluster_map_recluster.keys()).copy()
    for key in keys:
        if len(cluster_map_recluster[key]) == 0:
            cluster_map_recluster.pop(key)
    cluster_map = cluster_map_recluster

    
    G = nx.Graph()
    G.add_nodes_from(cluster_map.keys())

    all_pairs = list(itertools.combinations(cluster_map.keys(),2))
    for pair in tqdm(all_pairs):
        w_1  = cluster_trainers[pair[0]].model.state_dict()
        w_2 = cluster_trainers[pair[1]].model.state_dict()
        norm_diff = model_weights_diff(w_1, w_2)
        if norm_diff < thresh:
            G.add_edge(pair[0], pair[1])
    G = G.to_undirected()
    clustering = []        
    correlation_clustering(G)
    merge_clusters = [cluster  for cluster in clustering if len(cluster) > 0]
    
    #merge_cluster_map = {i: clusters[i] for i in range(len(clusters))}
    #clusters = list(nx.algorithms.clique.enumerate_all_cliques(G))
    cluster_map_new = {}
    for i in range(len(merge_clusters)):
        cluster_map_new[i] = []
        for j in merge_clusters[i]:
            cluster_map_new[i] += cluster_map[j]
    cluster_map = cluster_map_new
    test_acc = 0
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        model_wts = [cluster_trainers[j].model.state_dict() for j in merge_clusters[cluster_id]]
        test_acc_cluster, model_avg_wt =avg_acc(model_wts,cluster_clients)
        torch.save(model_avg_wt, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "merged_cluster_{}.pth".format(cluster_id)))
        test_acc += test_acc_cluster
    test_acc = test_acc/len(cluster_map.keys())
    torch.save(test_acc, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "avg_acc.pth"))


  0%|                                                                          | 0/2 [00:00<?, ?it/s]
  0%|                                                                          | 0/5 [00:00<?, ?it/s][A

  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<01:26,  1.14it/s][A[A

  2%|█▎                                                              | 2/100 [00:00<00:41,  2.33it/s][A[A

Iteration : 0 
 , Train Loss : 7.13593504163954 
, Test Acc : 14.167271614074707 





  3%|█▉                                                              | 3/100 [00:02<01:18,  1.23it/s][A[A

  5%|███▏                                                            | 5/100 [00:02<00:41,  2.29it/s][A[A

  6%|███▊                                                            | 6/100 [00:02<00:32,  2.90it/s][A[A

  7%|████▍                                                           | 7/100 [00:02<00:27,  3.40it/s][A[A

  8%|█████                                                           | 8/100 [00:02<00:21,  4.22it/s][A[A

  9%|█████▊                                                          | 9/100 [00:03<00:20,  4.40it/s][A[A

 11%|██████▉                                                        | 11/100 [00:03<00:16,  5.43it/s][A[A

 13%|████████▏                                                      | 13/100 [00:03<00:15,  5.77it/s][A[A

 14%|████████▊                                                      | 14/100 [00:03<00:13,  6.33it/s][A[A

 15%|█████████▍  

Iteration : 50 
 , Train Loss : 0.5275466458665 
, Test Acc : 65.75405883789062 





 53%|█████████████████████████████████▍                             | 53/100 [00:09<00:06,  7.04it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:09<00:06,  7.15it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:09<00:05,  7.22it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:10<00:06,  5.94it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:11<00:08,  4.65it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:11<00:08,  4.40it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:12<00:08,  4.36it/s][A[A

 66%|█████████████████████████████████████████▌                     | 66/100 [00:12<00:07,  4.84it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:13<00:14,  2.26it/s][A[A

 68%|████████████

Iteration : 99 
 , Train Loss : 1.3903063767486148 
, Test Acc : 60.66206741333008 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:18,  5.43it/s][A[A

Iteration : 0 
 , Train Loss : 6.390604615211487 
, Test Acc : 7.275132179260254 





  3%|█▉                                                              | 3/100 [00:00<00:12,  8.05it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:15,  6.28it/s][A[A

  7%|████▍                                                           | 7/100 [00:01<00:24,  3.79it/s][A[A

  9%|█████▊                                                          | 9/100 [00:03<00:48,  1.89it/s][A[A

 11%|██████▉                                                        | 11/100 [00:03<00:33,  2.66it/s][A[A

 13%|████████▏                                                      | 13/100 [00:03<00:24,  3.55it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:04<00:18,  4.48it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:04<00:16,  5.18it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:04<00:13,  6.20it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 1.4060882776975632 
, Test Acc : 61.544715881347656 





 55%|██████████████████████████████████▋                            | 55/100 [00:08<00:05,  8.64it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:09<00:04,  8.67it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:09<00:04,  8.89it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:09<00:04,  8.72it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:10<00:05,  7.04it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:10<00:06,  5.22it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:10<00:05,  6.25it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:11<00:04,  6.88it/s][A[A

 71%|████████████████████████████████████████████▋                  | 71/100 [00:11<00:04,  7.12it/s][A[A

 73%|████████████

Iteration : 99 
 , Train Loss : 18.222400903701782 
, Test Acc : 55.74074172973633 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:14,  6.70it/s][A[A

Iteration : 0 
 , Train Loss : 8.21332057317098 
, Test Acc : 16.29901885986328 





  3%|█▉                                                              | 3/100 [00:00<00:12,  7.49it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:16,  5.84it/s][A[A

  7%|████▍                                                           | 7/100 [00:01<00:15,  6.19it/s][A[A

  9%|█████▊                                                          | 9/100 [00:01<00:13,  6.82it/s][A[A

 11%|██████▉                                                        | 11/100 [00:01<00:12,  7.07it/s][A[A

 13%|████████▏                                                      | 13/100 [00:01<00:11,  7.43it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:02<00:11,  7.73it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:02<00:10,  8.01it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:02<00:09,  8.53it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 2.3551865418752036 
, Test Acc : 58.946075439453125 





 53%|█████████████████████████████████▍                             | 53/100 [00:08<00:14,  3.15it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:10<00:21,  2.09it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:11<00:21,  1.97it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:11<00:15,  2.58it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:12<00:11,  3.35it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:12<00:08,  4.17it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:12<00:07,  4.87it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:12<00:05,  5.82it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:12<00:04,  6.64it/s][A[A

 71%|████████████

Iteration : 99 
 , Train Loss : 30.355083147684734 
, Test Acc : 34.68136978149414 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:36,  2.74it/s][A[A

Iteration : 0 
 , Train Loss : 8.413390159606934 
, Test Acc : 20.855613708496094 





  3%|█▉                                                              | 3/100 [00:01<00:50,  1.91it/s][A[A

  5%|███▏                                                            | 5/100 [00:03<01:22,  1.15it/s][A[A

  7%|████▍                                                           | 7/100 [00:08<02:17,  1.48s/it][A[A

  9%|█████▊                                                          | 9/100 [00:08<01:28,  1.03it/s][A[A

 11%|██████▉                                                        | 11/100 [00:09<00:59,  1.48it/s][A[A

 13%|████████▏                                                      | 13/100 [00:09<00:41,  2.11it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:09<00:29,  2.89it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:09<00:21,  3.85it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:09<00:17,  4.56it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 0.11185227148234844 
, Test Acc : 51.87165832519531 





 55%|██████████████████████████████████▋                            | 55/100 [00:13<00:03, 12.16it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:13<00:03, 12.48it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:13<00:03, 12.13it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:13<00:03, 12.59it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:14<00:02, 12.67it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:14<00:02, 13.02it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:14<00:02, 13.27it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:14<00:02, 12.07it/s][A[A

 71%|████████████████████████████████████████████▋                  | 71/100 [00:14<00:02, 12.38it/s][A[A

 73%|████████████

Iteration : 99 
 , Train Loss : 0.08007687237113714 
, Test Acc : 60.96257019042969 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:15,  6.49it/s][A[A

Iteration : 0 
 , Train Loss : 6.013992071151733 
, Test Acc : 4.640151500701904 





  3%|█▉                                                              | 3/100 [00:00<00:13,  7.21it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:16,  5.85it/s][A[A

  7%|████▍                                                           | 7/100 [00:01<00:13,  6.81it/s][A[A

  9%|█████▊                                                          | 9/100 [00:01<00:12,  7.41it/s][A[A

 11%|██████▉                                                        | 11/100 [00:01<00:11,  7.74it/s][A[A

 13%|████████▏                                                      | 13/100 [00:01<00:11,  7.40it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:02<00:14,  5.95it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:03<00:19,  4.34it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:03<00:18,  4.42it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 0.3179122656583786 
, Test Acc : 76.89393615722656 





 53%|█████████████████████████████████▍                             | 53/100 [00:07<00:08,  5.58it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:08<00:10,  4.33it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:09<00:11,  3.59it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:10<00:16,  2.51it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:11<00:18,  2.10it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:12<00:14,  2.48it/s][A[A

 64%|████████████████████████████████████████▎                      | 64/100 [00:12<00:12,  2.86it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:12<00:11,  3.03it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:13<00:08,  3.89it/s][A[A

 69%|████████████

Iteration : 99 
 , Train Loss : 0.13882844522595406 
, Test Acc : 83.14393615722656 




  0%|                                                                         | 0/32 [00:00<?, ?it/s][A
  3%|██                                                               | 1/32 [00:00<00:07,  4.11it/s][A
  6%|████                                                             | 2/32 [00:01<00:25,  1.17it/s][A
  9%|██████                                                           | 3/32 [00:02<00:20,  1.44it/s][A
 12%|████████▏                                                        | 4/32 [00:02<00:19,  1.41it/s][A
 16%|██████████▏                                                      | 5/32 [00:03<00:18,  1.43it/s][A
 19%|████████████▏                                                    | 6/32 [00:04<00:17,  1.45it/s][A
 22%|██████████████▏                                                  | 7/32 [00:04<00:14,  1.67it/s][A
 25%|████████████████▎                                                | 8/32 [00:04<00:11,  2.10it/s][A
 28%|██████████████████▎                              

Iteration : 0 
 , Train Loss : 6.041143751144409 
, Test Acc : 9.326723098754883 





  3%|█▉                                                              | 3/100 [00:00<00:32,  2.94it/s][A[A

  4%|██▌                                                             | 4/100 [00:01<00:25,  3.83it/s][A[A

  5%|███▏                                                            | 5/100 [00:01<00:35,  2.70it/s][A[A

  6%|███▊                                                            | 6/100 [00:01<00:27,  3.36it/s][A[A

  7%|████▍                                                           | 7/100 [00:02<00:26,  3.55it/s][A[A

  8%|█████                                                           | 8/100 [00:02<00:22,  4.15it/s][A[A

  9%|█████▊                                                          | 9/100 [00:02<00:22,  4.10it/s][A[A

 10%|██████▎                                                        | 10/100 [00:02<00:19,  4.73it/s][A[A

 11%|██████▉                                                        | 11/100 [00:02<00:19,  4.52it/s][A[A

 12%|███████▌    

Iteration : 50 
 , Train Loss : 0.6049617528915405 
, Test Acc : 78.37081909179688 





 53%|█████████████████████████████████▍                             | 53/100 [00:15<00:44,  1.06it/s][A[A

 54%|██████████████████████████████████                             | 54/100 [00:15<00:31,  1.44it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:16<00:30,  1.50it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:17<00:21,  2.00it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:17<00:15,  2.64it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:17<00:11,  3.52it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:18<00:08,  4.37it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:18<00:06,  5.25it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:18<00:06,  4.98it/s][A[A

 69%|████████████

Iteration : 99 
 , Train Loss : 0.6529629677534103 
, Test Acc : 80.82542419433594 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<01:28,  1.12it/s][A[A

Iteration : 0 
 , Train Loss : 7.405919530174949 
, Test Acc : 7.127121448516846 





  2%|█▎                                                              | 2/100 [00:01<00:52,  1.87it/s][A[A

  3%|█▉                                                              | 3/100 [00:01<00:54,  1.77it/s][A[A

  4%|██▌                                                             | 4/100 [00:02<00:44,  2.18it/s][A[A

  5%|███▏                                                            | 5/100 [00:02<00:43,  2.16it/s][A[A

  6%|███▊                                                            | 6/100 [00:02<00:36,  2.57it/s][A[A

  7%|████▍                                                           | 7/100 [00:03<00:40,  2.29it/s][A[A

  8%|█████                                                           | 8/100 [00:03<00:36,  2.50it/s][A[A

  9%|█████▊                                                          | 9/100 [00:04<00:36,  2.52it/s][A[A

 10%|██████▎                                                        | 10/100 [00:04<00:32,  2.79it/s][A[A

 11%|██████▉     

Iteration : 50 
 , Train Loss : 0.6329287595369599 
, Test Acc : 66.53329467773438 





 52%|████████████████████████████████▊                              | 52/100 [00:19<00:17,  2.79it/s][A[A

 53%|█████████████████████████████████▍                             | 53/100 [00:20<00:18,  2.49it/s][A[A

 54%|██████████████████████████████████                             | 54/100 [00:20<00:16,  2.72it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:21<00:18,  2.46it/s][A[A

 56%|███████████████████████████████████▎                           | 56/100 [00:21<00:16,  2.69it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:21<00:16,  2.59it/s][A[A

 58%|████████████████████████████████████▌                          | 58/100 [00:22<00:15,  2.72it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:22<00:16,  2.45it/s][A[A

 60%|█████████████████████████████████████▊                         | 60/100 [00:22<00:14,  2.72it/s][A[A

 61%|████████████

Iteration : 99 
 , Train Loss : 0.47590516033497726 
, Test Acc : 68.18305969238281 




  0%|                                                                         | 0/32 [00:00<?, ?it/s][A
  3%|██                                                               | 1/32 [00:00<00:09,  3.43it/s][A
  6%|████                                                             | 2/32 [00:00<00:08,  3.45it/s][A
  9%|██████                                                           | 3/32 [00:00<00:08,  3.36it/s][A
 12%|████████▏                                                        | 4/32 [00:01<00:09,  2.98it/s][A
 16%|██████████▏                                                      | 5/32 [00:01<00:11,  2.29it/s][A
 19%|████████████▏                                                    | 6/32 [00:02<00:09,  2.82it/s][A
 22%|██████████████▏                                                  | 7/32 [00:02<00:07,  3.32it/s][A
 28%|██████████████████▎                                              | 9/32 [00:02<00:04,  4.94it/s][A
 34%|██████████████████████                           

{0: [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31]}

In [59]:
for cluster_id in tqdm(cluster_map.keys()):
    cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
    model_wts = [cluster_trainers[j].model.state_dict() for j in merge_clusters[cluster_id]]
    test_acc_cluster, model_avg_wt =avg_acc(model_wts,cluster_clients)
    torch.save(model_avg_wt, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "merged_cluster_{}.pth".format(cluster_id)))
    test_acc += test_acc_cluster
test_acc = test_acc/len(cluster_map.keys())
torch.save(test_acc, os.path.join(config['results_dir'], "refine_{}".format(refine_step), "avg_acc.pth"))


100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.63it/s]


In [77]:
print(test_acc)

tensor(68.4903)


In [79]:
class GlobalTrainer(BaseTrainer):
    def __init__(self,  config, save_dir):
        super(GlobalTrainer, self).__init__(config, save_dir)
        
    def train(self, client_data_list):
        num_clients = len(client_data_list)

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        
        
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        #eff_num_workers = int(num_clients/(1 - 2*beta))
        # if eff_num_workers > 0:
        #     eff_batch_size = self.config["train_batch"]/eff_num_workers
        #     for i in range(num_clients):
        #         client_data_list[i].trainloader.batch_size = eff_batch_size
                
        for iteration in tqdm(range(self.config["iterations"])):
            trmean_buffer = {}
            for idx, param in self.model.named_parameters():
                if param.requires_grad:
                    trmean_buffer[idx] = []
            train_loss = 0
            #optimizer.zero_grad(set_to_none=True)

            for client in client_data_list:
                #if eff_num_workers>0:
                optimizer.zero_grad(set_to_none=True)
                (X,Y) = client.sample_batch()
                X = X.to(config["device"])
                Y = Y.to(config["device"])
                loss_func = nn.CrossEntropyLoss()
                out = self.model(X)
                loss = loss_func(out,Y)
                loss.backward()
                train_loss += loss.detach().cpu().numpy().item()
                
                with torch.no_grad():
                    for idx, param in self.model.named_parameters():
                        if param.requires_grad:
                            trmean_buffer[idx].append(param.grad.clone())
            train_loss = train_loss/num_clients
            optimizer.zero_grad()
            
            start_idx = 0
            end_idx = num_clients


            for idx, param in self.model.named_parameters():
                if param.requires_grad:
                    sorted, _  = torch.sort(torch.stack(trmean_buffer[idx], dim=0), dim=0)
                    new_grad = sorted[start_idx:end_idx,...].mean(dim=0)
                    param.grad = new_grad
                    trmean_buffer[idx] = []
            optimizer.step()
            
            train_loss_list.append(train_loss)
            test_acc = 0
            for client_data in client_data_list:
                test_acc += calc_acc(self.model, self.device, client_data, train=False)
            test_acc = test_acc/num_clients
            test_acc_list.append(test_acc)
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == self.config["iterations"] - 1:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))
                
        self.model.eval()
        self.model.cpu()


    def test(self, client_data_list):
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        self.model.cpu()
        return test_acc

global_trainer = GlobalTrainer(config, os.path.join(config["results_dir"], "global"))
global_trainer.train(client_loaders)


  1%|▋                                                               | 1/100 [00:00<00:54,  1.83it/s]

Iteration : 0 
 , Train Loss : 7.147231712937355 
, Test Acc : 4.0661163330078125 



 51%|████████████████████████████████▏                              | 51/100 [00:22<00:25,  1.92it/s]

Iteration : 50 
 , Train Loss : 0.8960115159861743 
, Test Acc : 73.05399322509766 



100%|██████████████████████████████████████████████████████████████| 100/100 [01:42<00:00,  1.03s/it]

Iteration : 99 
 , Train Loss : 0.5325381550937891 
, Test Acc : 77.18030548095703 






In [80]:
del global_trainer
import gc
gc.collect()
torch.cuda.empty_cache()