In [1]:
%load_ext autoreload
%autoreload 2

In [2]:




import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from abc import ABC
from tqdm import tqdm
import torchvision
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import json
from collections import OrderedDict
import shutil

In [5]:
# Remove the results directory left behind by an earlier run (seed 9998).
# Guarded with isdir so a fresh "Restart & Run All" does not stop with
# FileNotFoundError when that directory was never created or is already gone.
stale_results_dir = "./experiments/results/femnist/seed_9998"
if os.path.isdir(stale_results_dir):
    shutil.rmtree(stale_results_dir)

In [3]:
torch.cuda.empty_cache()

301

In [3]:
config = {}


In [4]:
config["seed"] = 0
seed = config["seed"]
os.environ['PYTHONHASHSEED'] = str(seed)
# Torch RNG
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Python RNG
np.random.seed(seed)
random.seed(seed)


In [5]:
config["participation_ratio"] = 0.05/6
#config["total_num_clients_per_cluster"] = 80
#config["num_clients_per_cluster"] = int(config["participation_ratio"]*config["total_num_clients_per_cluster"])
#config["num_clusters"] = 4
#config["num_clients"] = config["num_clients_per_cluster"]*config["num_clusters"]
config["dataset"] = "femnist"
#DATASET_LIB = {"mnist" : torchvision.datasets.MNIST, "emnist": torchvision.datasets.EMNIST, "cifar10": torchvision.datasets.CIFAR10}
config["dataset_dir"] = "/base_vol/femnist/data"
config["results_dir"] = "./experiments/results"
config["results_dir"] = os.path.join(config["results_dir"], config["dataset"], "seed_{}".format(seed))
os.makedirs(config["results_dir"], exist_ok=True)


In [6]:
from collections import defaultdict
def read_dir(data_dir):
    '''Load every ``.json`` file in ``data_dir`` and merge their contents.

    Each file must contain a ``'user_data'`` mapping and may contain a
    ``'hierarchies'`` list. Files are visited in sorted filename order so that
    the order of ``groups`` (and which file wins when two files share a
    ``'user_data'`` key) does not depend on the arbitrary order returned by
    ``os.listdir``.

    Return:
        clients: sorted list of the keys of the merged ``'user_data'``
        groups: concatenation of all ``'hierarchies'`` lists; empty list if none found
        data: defaultdict merging every file's ``'user_data'``; a missing key yields None
    '''
    groups = []
    data = defaultdict(lambda : None)

    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for fname in json_files:
        with open(os.path.join(data_dir, fname), 'r') as inf:
            cdata = json.load(inf)
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        # Later files overwrite earlier ones on duplicate keys.
        data.update(cdata['user_data'])

    # The client list is derived from the merged data rather than from the
    # per-file 'users' entries.
    clients = sorted(data.keys())
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    '''Load the train and test splits and check that they line up.

    Both directories are parsed with ``read_dir``. The client lists and the
    group lists of the two splits must be equal; otherwise an AssertionError
    is raised.

    Return:
        clients: list of client ids (taken from the train split)
        groups: list of group ids (taken from the train split); empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    '''
    clients, groups, train_data = read_dir(train_data_dir)
    test_split = read_dir(test_data_dir)

    assert clients == test_split[0]
    assert groups == test_split[1]

    return clients, groups, train_data, test_split[2]


In [7]:
config["total_clients"], _, train_data, test_data = read_data(os.path.join(config["dataset_dir"],"train"), os.path.join(config["dataset_dir"],"test"))


In [8]:
class ClientDataset(Dataset):
    """Dataset over one client's samples.

    Args:
        data: pair ``(samples, labels)``; ``samples[idx]`` must support
            ``.reshape(28, 28)`` and ``labels[idx]`` is returned unchanged.
        transforms: optional callable applied to each reshaped 28x28 sample.
            Its result must provide ``.float()``.
    """
    def __init__(self, data,transforms = None):
        super(ClientDataset,self).__init__()
        self.data = data[0]
        self.labels = data[1]
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        """Return ``(sample_as_float_tensor, label)`` for index ``idx``."""
        image = self.data[idx].reshape(28,28)
        if self.transforms is not None:
            sample = self.transforms(image)
        else:
            # Without a transform the sample is still a NumPy array, which has
            # no .float(); convert it so the untransformed path works too.
            sample = torch.as_tensor(image)
        return (sample.float(), self.labels[idx])

class Client():
    """Bundles one client's train/test datasets, loaders and batch iterators.

    ``save_dir`` is extended to ``<save_dir>/init/client_<client_id>``.
    The train loader shuffles; the test loader does not.
    """
    def __init__(self, train_data, test_data, client_id,  train_transforms, test_transforms, train_batch_size, test_batch_size, save_dir):
        self.client_id = client_id
        self.save_dir = os.path.join(save_dir, "init", "client_{}".format(client_id))

        self.trainset = ClientDataset(train_data, train_transforms)
        self.testset = ClientDataset(test_data, test_transforms)

        self.trainloader = DataLoader(self.trainset, batch_size = train_batch_size, shuffle=True)
        self.testloader = DataLoader(self.testset, batch_size = test_batch_size, shuffle=False)

        # Persistent iterators so sample_batch can hand out one batch per call.
        self.train_iterator = iter(self.trainloader)
        self.test_iterator = iter(self.testloader)

    def sample_batch(self, train=True):
        """Return the next ``(data, labels)`` batch of the chosen split.

        When the split's iterator is exhausted it is re-created from the
        loader, stored back on the instance, and its first batch is returned.
        """
        if train:
            iterator_attr, loader = "train_iterator", self.trainloader
        else:
            iterator_attr, loader = "test_iterator", self.testloader

        try:
            data, labels = next(getattr(self, iterator_attr))
        except StopIteration:
            restarted = iter(loader)
            setattr(self, iterator_attr, restarted)
            data, labels = next(restarted)
        return (data, labels)


In [9]:
config["train_batch"] = 64
config["test_batch"] = 512
config["num_clients"] = int(len(config["total_clients"])/6)


In [10]:
## Generate new clients
selected_clients_path = os.path.join(config["results_dir"], "selected_clients.json")
if os.path.exists(selected_clients_path):
    with open(selected_clients_path, "r") as f:
        config["selected_clients"] = json.load(f)
else:
    config["selected_clients"] = random.sample(config["total_clients"], config["num_clients"])
    with open(selected_clients_path, "w") as f:
        json.dump(config["selected_clients"], f)

In [11]:
client_loaders = []
for client_id in config["selected_clients"]:
        client_loaders.append(
            Client(
                (np.array(train_data[client_id]['x']), np.array(train_data[client_id]['y'])),
                (np.array(test_data[client_id]['x']), np.array(test_data[client_id]['y'])),
                client_id,
                train_transforms=torchvision.transforms.ToTensor(),
                test_transforms=torchvision.transforms.ToTensor(),
                train_batch_size=config["train_batch"],
                test_batch_size=config["test_batch"],
                save_dir=config["results_dir"],
            )
        )


In [12]:
class SimpleCNN(torch.nn.Module):
    """Two conv/max-pool stages followed by two fully connected layers.

    The first conv takes 1 input channel and the classifier emits 62 logits.
    ``fc1`` expects ``64*7*7`` features, i.e. 28x28 inputs after the two
    stride-2 pools.

    Args:
        h1: width of the hidden fully connected layer. The default (2048)
            reproduces the previously hard-coded size.
    """

    def __init__(self, h1=2048):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size = (5,5), padding="same")
        self.pool1 = torch.nn.MaxPool2d((2,2), stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size= (5,5), padding = "same")
        self.pool2 = torch.nn.MaxPool2d((2,2), stride=2)
        # h1 was previously accepted but ignored; it now sets the hidden width.
        self.fc1 = torch.nn.Linear(64*7*7, h1)
        self.fc2 = torch.nn.Linear(h1, 62)

    def forward(self, x):
        """Return the 62 class logits for a batch ``x``."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [13]:
def set_weights(model):
    """Load the weights stored at '/base_vol/model_wt_dict.pt' into ``model``.

    The stored dict uses ``<layer>/kernel`` and ``<layer>/bias`` keys
    (presumably exported from a TensorFlow/Keras model - confirm). Dense
    kernels are transposed and conv kernels are permuted with (3, 2, 0, 1)
    before being assigned to the PyTorch parameter names. The model is then
    passed through ``freeze_layers`` and returned.
    """
    model_wt = torch.load('/base_vol/model_wt_dict.pt')

    transposed_kernels = (
        ('fc2.weight', "dense_1/kernel"),
        ('fc1.weight', "dense/kernel"),
    )
    permuted_kernels = (
        ("conv1.weight", "conv2d/kernel"),
        ("conv2.weight", "conv2d_1/kernel"),
    )
    biases = (
        ('fc2.bias', "dense_1/bias"),
        ('fc1.bias', "dense/bias"),
        ("conv1.bias", "conv2d/bias"),
        ("conv2.bias", "conv2d_1/bias"),
    )

    new_wts = OrderedDict()
    for target, source in transposed_kernels:
        new_wts[target] = torch.Tensor(model_wt[source]).t()
    for target, source in permuted_kernels:
        new_wts[target] = torch.Tensor(model_wt[source]).permute(3,2,0,1)
    for target, source in biases:
        new_wts[target] = torch.Tensor(model_wt[source])

    model.load_state_dict(new_wts)
    return freeze_layers(model)
def freeze_layers(model):
    """Freeze both conv layers and leave both fully connected layers trainable.

    Sets ``requires_grad`` on the weight and bias of ``conv1``/``conv2`` to
    False and on ``fc1``/``fc2`` to True, then returns the same ``model``.
    """
    for layer in (model.conv1, model.conv2):
        layer.weight.requires_grad = False
        layer.bias.requires_grad = False
    for layer in (model.fc1, model.fc2):
        layer.weight.requires_grad = True
        layer.bias.requires_grad = True
    return model

In [14]:
def calc_acc(model, device, client_data, train):
    """Return the accuracy (in percent) of ``model`` on one client's data.

    Uses ``client_data.trainloader`` when ``train`` is truthy, otherwise
    ``client_data.testloader``. The model is switched to eval mode and moved
    to ``device``; it is left in that state on return. Predictions are the
    argmax over axis 1 and are compared with the labels on the CPU.
    """
    batches = client_data.trainloader if train else client_data.testloader
    model.eval()
    model.to(device)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in batches:
            predicted = model(inputs.to(device)).argmax(axis=1).detach().cpu()
            correct += (targets == predicted).float().sum()
            total += targets.shape[0]

    return (correct / total) * 100.0


In [15]:
class BaseTrainer(ABC):
    """Common state and checkpoint helpers shared by the trainers.

    Args:
        config: experiment configuration. Reads ``"model"``,
            ``"model_params"``, ``"device"`` and ``"loss_func"``.
        save_dir: directory for ``model.pth`` and the metrics files; created
            if it does not exist.
    """
    def __init__(self,config, save_dir):
        super(BaseTrainer, self).__init__()
        # Every trainer starts from the weights loaded by set_weights.
        self.model = set_weights(MODEL_LIST[config["model"]](**config["model_params"]))
        self.save_dir = save_dir
        self.device = config["device"]
        self.loss_func = LOSSES[config["loss_func"]]
        self.config = config
        os.makedirs(self.save_dir, exist_ok=True)

    def train(self):
        """Overridden by subclasses."""
        raise NotImplementedError

    def test(self):
        """Overridden by subclasses."""
        raise NotImplementedError

    def load_model_weights(self):
        """Load ``<save_dir>/model.pth`` into the model, or print a notice if it is absent."""
        model_path  = os.path.join(self.save_dir, "model.pth")
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path))
        else:
            # The path was previously missing from format(), which raised
            # IndexError instead of printing this notice.
            print("No model present at path : {}".format(model_path))

    def save_model_weights(self):
        """Write the model's state dict to ``<save_dir>/model.pth``."""
        model_path  = os.path.join(self.save_dir, "model.pth")
        torch.save(self.model.state_dict(), model_path)

    def save_metrics(self, train_loss, test_acc, iteration):
        """Save ``train_loss`` and ``test_acc`` to ``<save_dir>/metrics_<iteration>.pkl`` via torch.save."""
        torch.save({"train_loss": train_loss,  "test_acc" : test_acc}, os.path.join(self.save_dir,"metrics_{}.pkl".format(iteration)))

class ClientTrainer(BaseTrainer):
    """Trains and evaluates one model on a single client's data."""
    def __init__(self,  config, save_dir,client_id):
        super(ClientTrainer, self).__init__(config, save_dir)
        self.client_id = client_id

    def train(self, client_data):
        """Run ``config["iterations"]`` optimisation steps on ``client_data``.

        Each iteration draws one training batch through
        ``client_data.sample_batch`` and then evaluates on the client's test
        loader. Weights and metrics are saved every ``config["save_freq"]``
        iterations and on the last one; progress is printed every
        ``config["print_freq"]`` iterations and on the last one. The model is
        left in eval mode on the CPU.
        """
        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()
        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        last_iteration = self.config["iterations"] - 1
        for iteration in range(self.config["iterations"]):
            self.model.zero_grad()
            (X,Y) = client_data.sample_batch(train=True)
            X = X.to(self.device)
            Y = Y.to(self.device)
            out = self.model(X)
            loss = self.loss_func(out, Y)
            loss.backward()
            optimizer.step()
            train_loss = loss.detach().cpu().numpy().item()
            train_loss_list.append(train_loss)
            test_acc = calc_acc(self.model, self.device, client_data, train=False)
            test_acc_list.append(test_acc)
            # calc_acc switches the model to eval mode; restore training mode.
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == last_iteration:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == last_iteration:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))

        self.model.eval()
        self.model.cpu()


    def test(self, client_data):
        """Reload the saved weights and return the accuracy on the client's test loader."""
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        # calc_acc requires the device and the split flag; the previous
        # two-argument call raised TypeError.
        acc = calc_acc(self.model, self.device, client_data, train=False)
        self.model.cpu()
        return acc




In [16]:
  
# Registries that the trainers look up by the names stored in `config`.
MODEL_LIST = {"cnn" : SimpleCNN}
OPTIMIZER_LIST = {"sgd": optim.SGD, "adam": optim.Adam}
LOSSES = {"cross_entropy": nn.CrossEntropyLoss()}
# config["save_dir"] = os.path.join("./results")
config["iterations"] = 100
config["optimizer_params"] = {"lr":0.001}
config["save_freq"] = 2
config["print_freq"]  = 50
config["model"] = "cnn"
config["optimizer"] = "adam"
config["loss_func"] = "cross_entropy"
#config["model_params"] = {"num_channels": 1 , "num_classes"  : 62}
config["model_params"] = {}
# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
config["device"] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import pickle






In [21]:
client_trainers = [ClientTrainer(config,os.path.join(config["results_dir"], "init", client_id), client_id) for client_id in config["selected_clients"]]


In [23]:
for i in tqdm(range(len(config["selected_clients"]))):
    client_trainers[i].train(client_loaders[i])


  0%|                                                                         | 0/32 [00:00<?, ?it/s]

Iteration : 0 
 , Train Loss : 8.403800010681152 
, Test Acc : 12.5 

Iteration : 50 
 , Train Loss : 0.036591965705156326 
, Test Acc : 68.75 



  3%|██                                                               | 1/32 [00:16<08:43, 16.88s/it]

Iteration : 99 
 , Train Loss : 0.025636300444602966 
, Test Acc : 68.75 

Iteration : 0 
 , Train Loss : 5.6712965965271 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.011240771971642971 
, Test Acc : 0.0 



  6%|████                                                             | 2/32 [00:26<06:22, 12.75s/it]

Iteration : 99 
 , Train Loss : 0.004137152340263128 
, Test Acc : 0.0 

Iteration : 0 
 , Train Loss : 9.131110191345215 
, Test Acc : 14.285715103149414 

Iteration : 50 
 , Train Loss : 0.016862623393535614 
, Test Acc : 64.28571319580078 



  9%|██████                                                           | 3/32 [00:44<07:12, 14.90s/it]

Iteration : 99 
 , Train Loss : 0.002873397897928953 
, Test Acc : 71.42857360839844 

Iteration : 0 
 , Train Loss : 6.688961505889893 
, Test Acc : 8.0 

Iteration : 50 
 , Train Loss : 0.16330811381340027 
, Test Acc : 60.000003814697266 



 12%|████████▏                                                        | 4/32 [01:00<07:13, 15.48s/it]

Iteration : 99 
 , Train Loss : 0.037593211978673935 
, Test Acc : 60.000003814697266 

Iteration : 0 
 , Train Loss : 7.267025947570801 
, Test Acc : 18.75 

Iteration : 50 
 , Train Loss : 0.1442040055990219 
, Test Acc : 75.0 



 16%|██████████▏                                                      | 5/32 [01:37<10:29, 23.32s/it]

Iteration : 99 
 , Train Loss : 0.10409393161535263 
, Test Acc : 68.75 

Iteration : 0 
 , Train Loss : 7.497664451599121 
, Test Acc : 28.571430206298828 

Iteration : 50 
 , Train Loss : 0.0009631660650484264 
, Test Acc : 71.42857360839844 



 19%|████████████▏                                                    | 6/32 [01:47<08:04, 18.63s/it]

Iteration : 99 
 , Train Loss : 0.0004981413367204368 
, Test Acc : 71.42857360839844 

Iteration : 0 
 , Train Loss : 7.037042617797852 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.031065799295902252 
, Test Acc : 20.0 



 22%|██████████████▏                                                  | 7/32 [01:55<06:18, 15.15s/it]

Iteration : 99 
 , Train Loss : 0.008788074366748333 
, Test Acc : 20.0 

Iteration : 0 
 , Train Loss : 6.559217929840088 
, Test Acc : 11.111111640930176 

Iteration : 50 
 , Train Loss : 0.09433613717556 
, Test Acc : 66.66667175292969 



 25%|████████████████▎                                                | 8/32 [02:04<05:21, 13.41s/it]

Iteration : 99 
 , Train Loss : 0.02506203018128872 
, Test Acc : 66.66667175292969 

Iteration : 0 
 , Train Loss : 6.957663536071777 
, Test Acc : 7.142857551574707 

Iteration : 50 
 , Train Loss : 0.26551195979118347 
, Test Acc : 50.0 



 28%|██████████████████▎                                              | 9/32 [03:01<10:17, 26.86s/it]

Iteration : 99 
 , Train Loss : 0.08959253877401352 
, Test Acc : 50.0 

Iteration : 0 
 , Train Loss : 8.287491798400879 
, Test Acc : 36.3636360168457 

Iteration : 50 
 , Train Loss : 0.0012678930070251226 
, Test Acc : 90.90909576416016 



 31%|████████████████████                                            | 10/32 [03:46<11:52, 32.37s/it]

Iteration : 99 
 , Train Loss : 0.0007202834822237492 
, Test Acc : 90.90909576416016 

Iteration : 0 
 , Train Loss : 7.022873401641846 
, Test Acc : 17.647058486938477 

Iteration : 50 
 , Train Loss : 0.04264714941382408 
, Test Acc : 58.82353210449219 



 34%|██████████████████████                                          | 11/32 [04:01<09:30, 27.19s/it]

Iteration : 99 
 , Train Loss : 0.014498859643936157 
, Test Acc : 58.82353210449219 

Iteration : 0 
 , Train Loss : 7.97418212890625 
, Test Acc : 100.0 

Iteration : 50 
 , Train Loss : 0.0 
, Test Acc : 100.0 



 38%|████████████████████████                                        | 12/32 [04:09<07:05, 21.25s/it]

Iteration : 99 
 , Train Loss : 0.0 
, Test Acc : 100.0 

Iteration : 0 
 , Train Loss : 6.342329502105713 
, Test Acc : 16.666667938232422 

Iteration : 50 
 , Train Loss : 0.10455552488565445 
, Test Acc : 72.22222137451172 



 41%|██████████████████████████                                      | 13/32 [04:16<05:23, 17.03s/it]

Iteration : 99 
 , Train Loss : 0.03691640496253967 
, Test Acc : 77.77777862548828 

Iteration : 0 
 , Train Loss : 6.055566310882568 
, Test Acc : 5.714285850524902 

Iteration : 50 
 , Train Loss : 0.21863023936748505 
, Test Acc : 65.71428680419922 



 44%|████████████████████████████                                    | 14/32 [04:28<04:36, 15.38s/it]

Iteration : 99 
 , Train Loss : 0.09655528515577316 
, Test Acc : 68.5714340209961 

Iteration : 0 
 , Train Loss : 6.021806240081787 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.12935380637645721 
, Test Acc : 63.15789031982422 



 47%|██████████████████████████████                                  | 15/32 [04:36<03:43, 13.16s/it]

Iteration : 99 
 , Train Loss : 0.027539853006601334 
, Test Acc : 68.42105102539062 

Iteration : 0 
 , Train Loss : 7.837393283843994 
, Test Acc : 18.75 

Iteration : 50 
 , Train Loss : 0.08174458146095276 
, Test Acc : 87.5 



 50%|████████████████████████████████                                | 16/32 [04:44<03:06, 11.64s/it]

Iteration : 99 
 , Train Loss : 0.06681635230779648 
, Test Acc : 87.5 

Iteration : 0 
 , Train Loss : 6.74637508392334 
, Test Acc : 9.523809432983398 

Iteration : 50 
 , Train Loss : 0.11120343953371048 
, Test Acc : 76.19047546386719 



 53%|██████████████████████████████████                              | 17/32 [05:09<03:57, 15.82s/it]

Iteration : 99 
 , Train Loss : 0.056501198559999466 
, Test Acc : 76.19047546386719 

Iteration : 0 
 , Train Loss : 7.328262805938721 
, Test Acc : 5.555555820465088 

Iteration : 50 
 , Train Loss : 0.12258609384298325 
, Test Acc : 72.22222137451172 



 56%|████████████████████████████████████                            | 18/32 [05:18<03:11, 13.67s/it]

Iteration : 99 
 , Train Loss : 0.04092942178249359 
, Test Acc : 72.22222137451172 

Iteration : 0 
 , Train Loss : 5.433204650878906 
, Test Acc : 33.333335876464844 

Iteration : 50 
 , Train Loss : 0.08512686938047409 
, Test Acc : 77.77777862548828 



 59%|██████████████████████████████████████                          | 19/32 [05:26<02:34, 11.87s/it]

Iteration : 99 
 , Train Loss : 0.017638977617025375 
, Test Acc : 66.66667175292969 

Iteration : 0 
 , Train Loss : 8.199926376342773 
, Test Acc : 23.52941131591797 

Iteration : 50 
 , Train Loss : 0.08586252480745316 
, Test Acc : 52.94117736816406 



 62%|████████████████████████████████████████                        | 20/32 [05:40<02:32, 12.72s/it]

Iteration : 99 
 , Train Loss : 0.042928822338581085 
, Test Acc : 52.94117736816406 

Iteration : 0 
 , Train Loss : 7.489833831787109 
, Test Acc : 12.5 

Iteration : 50 
 , Train Loss : 0.34132108092308044 
, Test Acc : 62.5 



 66%|██████████████████████████████████████████                      | 21/32 [05:53<02:19, 12.66s/it]

Iteration : 99 
 , Train Loss : 0.08666824549436569 
, Test Acc : 62.5 

Iteration : 0 
 , Train Loss : 6.907412052154541 
, Test Acc : 0.0 

Iteration : 50 
 , Train Loss : 0.1850575953722 
, Test Acc : 67.64705657958984 



 69%|████████████████████████████████████████████                    | 22/32 [06:07<02:11, 13.15s/it]

Iteration : 99 
 , Train Loss : 0.04333963617682457 
, Test Acc : 70.5882339477539 

Iteration : 0 
 , Train Loss : 7.1721086502075195 
, Test Acc : 8.333333969116211 

Iteration : 50 
 , Train Loss : 0.10060359537601471 
, Test Acc : 79.16667175292969 



 72%|██████████████████████████████████████████████                  | 23/32 [06:26<02:14, 14.99s/it]

Iteration : 99 
 , Train Loss : 0.010694082826375961 
, Test Acc : 83.33332824707031 

Iteration : 0 
 , Train Loss : 6.145168781280518 
, Test Acc : 9.67741870880127 

Iteration : 50 
 , Train Loss : 0.2902585566043854 
, Test Acc : 58.06451416015625 



 75%|████████████████████████████████████████████████                | 24/32 [06:41<01:58, 14.83s/it]

Iteration : 99 
 , Train Loss : 0.12190034240484238 
, Test Acc : 64.51612854003906 

Iteration : 0 
 , Train Loss : 5.732304573059082 
, Test Acc : 14.634145736694336 

Iteration : 50 
 , Train Loss : 0.3019353449344635 
, Test Acc : 87.80487823486328 



 78%|██████████████████████████████████████████████████              | 25/32 [07:15<02:25, 20.75s/it]

Iteration : 99 
 , Train Loss : 0.1557682901620865 
, Test Acc : 87.80487823486328 

Iteration : 0 
 , Train Loss : 7.57725715637207 
, Test Acc : 9.090909004211426 

Iteration : 50 
 , Train Loss : 0.04681456461548805 
, Test Acc : 63.6363639831543 



 81%|████████████████████████████████████████████████████            | 26/32 [07:23<01:41, 16.86s/it]

Iteration : 99 
 , Train Loss : 0.02719501033425331 
, Test Acc : 72.7272720336914 

Iteration : 0 
 , Train Loss : 6.136054992675781 
, Test Acc : 15.384615898132324 

Iteration : 50 
 , Train Loss : 0.09329338371753693 
, Test Acc : 53.84615707397461 



 84%|██████████████████████████████████████████████████████          | 27/32 [07:29<01:08, 13.63s/it]

Iteration : 99 
 , Train Loss : 0.027593420818448067 
, Test Acc : 53.84615707397461 

Iteration : 0 
 , Train Loss : 6.9609832763671875 
, Test Acc : 11.764705657958984 

Iteration : 50 
 , Train Loss : 0.06471414864063263 
, Test Acc : 64.70588684082031 



 88%|████████████████████████████████████████████████████████        | 28/32 [07:37<00:46, 11.73s/it]

Iteration : 99 
 , Train Loss : 0.05900261551141739 
, Test Acc : 64.70588684082031 

Iteration : 0 
 , Train Loss : 8.158488273620605 
, Test Acc : 6.25 

Iteration : 50 
 , Train Loss : 0.035312116146087646 
, Test Acc : 68.75 



 91%|██████████████████████████████████████████████████████████      | 29/32 [08:05<00:50, 16.70s/it]

Iteration : 99 
 , Train Loss : 0.003174713347107172 
, Test Acc : 68.75 

Iteration : 0 
 , Train Loss : 7.357988357543945 
, Test Acc : 17.647058486938477 

Iteration : 50 
 , Train Loss : 0.054672472178936005 
, Test Acc : 76.47058868408203 



 94%|████████████████████████████████████████████████████████████    | 30/32 [08:13<00:28, 14.14s/it]

Iteration : 99 
 , Train Loss : 0.01016292441636324 
, Test Acc : 76.47058868408203 

Iteration : 0 
 , Train Loss : 6.999804496765137 
, Test Acc : 14.285715103149414 

Iteration : 50 
 , Train Loss : 0.020850175991654396 
, Test Acc : 78.57142639160156 



 97%|██████████████████████████████████████████████████████████████  | 31/32 [08:21<00:12, 12.36s/it]

Iteration : 99 
 , Train Loss : 0.00518223037943244 
, Test Acc : 85.71428680419922 

Iteration : 0 
 , Train Loss : 5.453310489654541 
, Test Acc : 18.18181800842285 

Iteration : 50 
 , Train Loss : 0.08098278194665909 
, Test Acc : 81.81818389892578 



100%|████████████████████████████████████████████████████████████████| 32/32 [08:45<00:00, 16.43s/it]

Iteration : 99 
 , Train Loss : 0.013384316116571426 
, Test Acc : 81.81818389892578 






In [21]:
## Load model weights
for trainer in client_trainers:
    trainer.load_model_weights()


In [24]:
local_acc = 0
for i in range(len(client_loaders)):
    local_acc += calc_acc(client_trainers[i].model, config["device"], client_loaders[i], train=False)
local_acc = local_acc/len(client_loaders)
print(f"Local Accuracy : {local_acc}")

Local Accuracy : 67.80694580078125


In [2]:
prin

NameError: name 'client_trainers' is not defined

In [18]:
a = 0
for i, trainer in enumerate(client_trainers):
    a+= calc_acc(trainer.model, config["device"], client_loaders[i], train=False)

In [19]:
print(a/len(client_trainers))

tensor(66.1461)


In [59]:
# config["iterations"] = 100
# config["optimizer_params"] = {"lr":0.001}
# config["save_freq"] = 2
# config["print_freq"]  = 40
# config["model"] = "cnn"
# config["optimizer"] = "adam"
# config["loss_func"] = "cross_entropy"
# #config["model_params"] = {"num_channels": 1 , "num_classes"  : 62}
# config["model_params"] = {}
# config["device"] = torch.device("cuda:0")

# client_id = config["selected_clients"][0]
# client_trainers[0] = ClientTrainer(config,os.path.join(config["results_dir"], "init", client_id), client_id)
# client_trainers[0].train(client_loaders[0])
import networkx as nx



In [25]:
import networkx as nx
G = nx.Graph()
G.add_nodes_from(range(config["num_clients"]))
import itertools
def model_weights_diff(w_1, w_2):
    """Return the L2 distance between two weight dicts with identical keys.

    For every key the difference of the two entries is taken on the CPU and
    its squared norm is accumulated; the square root of the total is
    returned. Raises AssertionError when the key sets differ.
    """
    assert w_1.keys() == w_2.keys(), "Model weights have different keys"
    total_sq = 0
    for name, first in w_1.items():
        delta = first.cpu() - w_2[name].cpu()
        total_sq = total_sq + delta.norm()**2
    return np.sqrt(total_sq)
wt = client_trainers[0].model.state_dict()


In [26]:
all_pairs = list(itertools.combinations(range(config["num_clients"]),2))
arr = []
for pair in all_pairs:
    w_1  = client_trainers[pair[0]].model.state_dict()
    w_2 = client_trainers[pair[1]].model.state_dict()
    norm_diff = model_weights_diff(w_1, w_2)
    arr.append(norm_diff)
#thresh = torch.mean(torch.tensor(arr))


In [27]:
thresh = arr[torch.tensor(arr).argsort()[int(0.3*len(arr))-1]]


In [28]:
clustering = []
def correlation_clustering(G):
    """Partition the nodes of ``G`` into pivot clusters, consuming ``G``.

    While nodes remain, one pivot is drawn at random; the pivot and all of
    its current neighbours form a cluster and are removed from ``G``. Each
    cluster (pivot first) is appended to the module-level ``clustering``
    list, which the caller is expected to reset before the call. ``G`` is
    empty on return and the function returns None.
    """
    global clustering
    # Iterative rather than recursive, so the number of clusters is not
    # limited by the interpreter's recursion depth.
    while len(G.nodes) > 0:
        # random.sample needs a sequence: sampling directly from the node
        # view is deprecated since Python 3.9 and a TypeError from 3.11.
        new_cluster_pivot = random.sample(list(G.nodes), 1)[0]
        cluster = [new_cluster_pivot]
        # Snapshot the neighbours first: removing nodes mutates the adjacency.
        neighbors = list(G[new_cluster_pivot])
        for node in neighbors:
            cluster.append(node)
            G.remove_node(node)
        G.remove_node(new_cluster_pivot)
        clustering.append(cluster)


In [29]:
# Repeat the randomized pivot clustering until it produces at least two
# clusters with more than one member. The edge set is the same on every
# attempt; only the random pivots differ. The number of attempts is bounded
# so the cell cannot loop forever when no such clustering exists.
MAX_CLUSTERING_ATTEMPTS = 1000
for _attempt in range(MAX_CLUSTERING_ATTEMPTS):
    # nx.Graph is already undirected, so no to_undirected() copy is needed.
    G = nx.Graph()
    G.add_nodes_from(range(config["num_clients"]))
    # Connect two clients when their weight distance is below the threshold.
    for pair, dist in zip(all_pairs, arr):
        if dist < thresh:
            G.add_edge(pair[0], pair[1])
    clustering = []
    correlation_clustering(G)
    clusters = [cluster  for cluster in clustering if len(cluster) > 1 ]
    print(len(clusters))
    if len(clusters) >= 2:
        break
else:
    raise RuntimeError(
        "No clustering with at least two multi-client clusters was found in {} attempts".format(MAX_CLUSTERING_ATTEMPTS)
    )

cluster_map = {i: clusters[i] for i in range(len(clusters))}
beta = 0.2

4


since Python 3.9 and will be removed in a subsequent version.
  new_cluster_pivot = random.sample(G.nodes,1)[0]


In [30]:
os.makedirs(os.path.join(config["results_dir"],"refine_0"),exist_ok=True)
with open(os.path.join(config["results_dir"],"refine_0", "cluster_maps.pkl"), 'wb') as handle:
    pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [56]:
cluster_map

{0: [8, 2, 3, 21, 22, 27, 29, 30, 31],
 1: [0, 15, 24, 25],
 2: [6, 10, 17],
 3: [26, 5],
 4: [16, 13]}

In [31]:
class ClusterTrainer(BaseTrainer):
    """Trains one shared model for a cluster of clients.

    The per-client gradients are combined with a coordinate-wise trimmed
    mean. The trimming fraction is read from the module-level ``beta``
    variable at call time.
    """
    def __init__(self,  config, save_dir,cluster_id):
        super(ClusterTrainer, self).__init__(config, save_dir)
        self.cluster_id = cluster_id

    def train(self, client_data_list):
        """Run ``config["iterations"]`` trimmed-mean steps over the cluster.

        In each iteration every client contributes the gradient of one
        training batch. For each trainable parameter the client gradients
        are sorted element-wise, entries outside
        ``[int(beta*n), int((1-beta)*n))`` are dropped, and the mean of the
        rest becomes the gradient for the optimizer step. If that window
        would hold at most one client, all gradients are averaged instead.
        After each step the model is evaluated on every client's test
        loader. Weights and metrics are saved every ``config["save_freq"]``
        iterations and on the last one. The model is left in eval mode on
        the CPU.
        """
        num_clients = len(client_data_list)

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()

        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        last_iteration = self.config["iterations"] - 1

        for iteration in tqdm(range(self.config["iterations"])):
            # One list of per-client gradients for every trainable parameter.
            trmean_buffer = {
                name: []
                for name, param in self.model.named_parameters()
                if param.requires_grad
            }
            train_loss = 0

            for client in client_data_list:
                # Clear gradients so each client's gradient is recorded separately.
                optimizer.zero_grad(set_to_none=True)
                (X,Y) = client.sample_batch()
                # Use the trainer's own device and loss (set from the config
                # passed to this trainer), not the notebook-level `config`
                # or a freshly built loss.
                X = X.to(self.device)
                Y = Y.to(self.device)
                out = self.model(X)
                loss = self.loss_func(out,Y)
                loss.backward()
                train_loss += loss.detach().cpu().numpy().item()

                with torch.no_grad():
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            trmean_buffer[name].append(param.grad.clone())
            train_loss = train_loss/num_clients
            optimizer.zero_grad()

            # Window of sorted client gradients kept by the trimmed mean.
            start_idx = int(beta*num_clients)
            end_idx = int((1-beta)*num_clients)
            if end_idx <= start_idx + 1:
                # Too few clients to trim: fall back to the plain mean.
                start_idx = 0
                end_idx = num_clients

            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    # Sort along the client axis independently for every coordinate.
                    sorted_grads, _  = torch.sort(torch.stack(trmean_buffer[name], dim=0), dim=0)
                    param.grad = sorted_grads[start_idx:end_idx,...].mean(dim=0)
            optimizer.step()

            train_loss_list.append(train_loss)
            test_acc = 0
            for client_data in client_data_list:
                test_acc += calc_acc(self.model, self.device, client_data, train=False)
            test_acc = test_acc/num_clients
            test_acc_list.append(test_acc)
            # calc_acc switches the model to eval mode; restore training mode.
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == last_iteration:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == last_iteration:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))

        self.model.eval()
        self.model.cpu()


    def test(self, client_data_list):
        """Reload the saved weights and return the mean test accuracy over the given clients."""
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        test_acc = test_acc/len(client_data_list)
        self.model.cpu()
        return test_acc


def avg_acc(model_wts, client_data_list):
    """Average several state dicts and evaluate the averaged model.

    Only float32 entries are averaged; every other entry keeps the value
    from the first state dict. The averaged weights are loaded into a fresh
    ``SimpleCNN`` and its mean test accuracy over ``client_data_list`` is
    computed with ``calc_acc``.

    Args:
        model_wts: non-empty sequence of state dicts with identical keys.
        client_data_list: clients whose test loaders are evaluated.

    Returns:
        ``(test_acc, averaged_state_dict)``.
    """
    # Clone the first state dict: accumulating in place on the caller's
    # tensors would overwrite the weights of the model they belong to.
    averaged = OrderedDict((key, value.clone()) for key, value in model_wts[0].items())
    for wt in model_wts[1:]:
        for key in averaged.keys():
            if averaged[key].dtype == torch.float32:
                averaged[key] += wt[key]
    for key in averaged.keys():
        if averaged[key].dtype == torch.float32:
            averaged[key] = averaged[key]/len(model_wts)
    model = SimpleCNN()
    model.load_state_dict(averaged)
    # Evaluate on the first GPU when available, otherwise on the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(memory_format = torch.channels_last).to(device)
    test_acc = 0
    for client_data in client_data_list:
        test_acc += calc_acc(model, device, client_data, train=False)
    test_acc = test_acc/len(client_data_list)
    return test_acc, averaged


In [32]:
config["refine_steps"] = 2

In [None]:
import re

refine_step = 0
refine_path = os.path.join(config["results_dir"], f"refine_{refine_step}")
# Raw string: "\d" inside a normal string literal is an invalid escape
# sequence, which newer Python versions warn about.
num_clusters = [x for x in os.listdir(refine_path) if re.match(r"cluster_\d", x) is not None]

# NOTE(review): the map is read from "refine_1" although refine_path points at
# refine_0 - confirm this is intended.
with open(os.path.join(config["results_dir"], "refine_1", "cluster_maps.pkl"), "rb") as handle:
    cluster_map = pickle.load(handle)

In [None]:
import re
# Sanity check of the directory-name pattern (raw string avoids the invalid "\d" escape).
if re.match(r"cluster_\d", "cluster_1") is not None:
    print("here")
# NOTE(review): os.listdir raises NotADirectoryError if "model.pth" is a regular file -- confirm it is a directory.
os.listdir(os.path.join(refine_path, "cluster_0", "model.pth"))

In [60]:
# Reload the cluster assignment pickled under the "refine_1" directory.
cluster_maps_file = os.path.join(config["results_dir"], "refine_1", "cluster_maps.pkl")
with open(cluster_maps_file, 'rb') as handle:
    cluster_map = pickle.load(handle)


In [64]:
# Size-weighted mean over clusters 0 and 1.
# NOTE(review): the two literals look like pasted per-cluster test accuracies -- confirm their source.
size_0 = len(cluster_map[0])
size_1 = len(cluster_map[1])
(size_0 * 80.82542419433594 + size_1 * 68.18305969238281) / (size_0 + size_1)

72.13379859924316

In [33]:
config["refine_steps"] = 2
for refine_step in tqdm(range(config["refine_steps"])):
    refine_dir = os.path.join(config["results_dir"], "refine_{}".format(refine_step))
    # A cluster_maps.pkl already in this step's directory means the step was
    # run before; wipe the directory so the step starts from scratch.
    if os.path.exists(os.path.join(refine_dir, "cluster_maps.pkl")):
        shutil.rmtree(refine_dir)
    # NOTE(review): beta is not read in this cell; it may be consumed as a
    # global by code defined elsewhere -- confirm before removing.
    beta = 0.15
    cluster_trainers = []

    # 1) Train one model per current cluster.
    # NOTE(review): cluster_trainers is indexed by cluster_id further down,
    # which is only correct while cluster_map's keys are 0..k-1 in iteration order.
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        cluster_trainer = ClusterTrainer(config, os.path.join(refine_dir, "cluster_{}".format(cluster_id)), cluster_id)
        cluster_trainer.train(cluster_clients)
        cluster_trainers.append(cluster_trainer)
    with open(os.path.join(refine_dir, "cluster_maps.pkl"), 'wb') as handle:
        pickle.dump(cluster_map, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # 2) Reassign every client to the cluster whose model is closest to its own.
    cluster_map_recluster = {}
    for key in cluster_map.keys():
        cluster_map_recluster[key] = []

    for i in tqdm(range(config["num_clients"])):
        w_node = client_trainers[i].model.state_dict()
        # np.inf rather than np.infty: the alias was removed in NumPy 2.0.
        norm_diff = np.inf
        new_cluster_id = 0
        for cluster_id in cluster_map.keys():
            w_cluster = cluster_trainers[cluster_id].model.state_dict()
            curr_norm_diff = model_weights_diff(w_node, w_cluster)
            if norm_diff > curr_norm_diff:
                new_cluster_id = cluster_id
                norm_diff = curr_norm_diff

        cluster_map_recluster[new_cluster_id].append(i)
    # Drop clusters that ended up with no clients (iterate over a snapshot of the keys).
    keys = list(cluster_map_recluster.keys())
    for key in keys:
        if len(cluster_map_recluster[key]) == 0:
            cluster_map_recluster.pop(key)
    cluster_map = cluster_map_recluster

    # 3) Connect clusters whose models differ by less than `thresh`.
    G = nx.Graph()
    G.add_nodes_from(cluster_map.keys())

    all_pairs = list(itertools.combinations(cluster_map.keys(), 2))
    for pair in tqdm(all_pairs):
        w_1 = cluster_trainers[pair[0]].model.state_dict()
        w_2 = cluster_trainers[pair[1]].model.state_dict()
        norm_diff = model_weights_diff(w_1, w_2)
        if norm_diff < thresh:
            G.add_edge(pair[0], pair[1])
    G = G.to_undirected()
    # NOTE(review): the return value of correlation_clustering is discarded;
    # presumably it fills the global `clustering` list reset here -- confirm.
    clustering = []
    correlation_clustering(G)
    merge_clusters = [cluster for cluster in clustering if len(cluster) > 0]

    # 4) Merge the client lists of every group of clusters and renumber the groups 0..m-1.
    cluster_map_new = {}
    for i in range(len(merge_clusters)):
        cluster_map_new[i] = []
        for j in merge_clusters[i]:
            cluster_map_new[i] += cluster_map[j]
    cluster_map = cluster_map_new

    # 5) Average the models of each merged group, evaluate and save the result.
    test_acc = 0
    for cluster_id in tqdm(cluster_map.keys()):
        cluster_clients = [client_loaders[i] for i in cluster_map[cluster_id]]
        model_wts = [cluster_trainers[j].model.state_dict() for j in merge_clusters[cluster_id]]
        test_acc_cluster, model_avg_wt = avg_acc(model_wts, cluster_clients)
        torch.save(model_avg_wt, os.path.join(refine_dir, "merged_cluster_{}.pth".format(cluster_id)))
        test_acc += test_acc_cluster
    # Unweighted mean over merged clusters (cluster sizes are not taken into account).
    test_acc = test_acc/len(cluster_map.keys())
    torch.save(test_acc, os.path.join(refine_dir, "avg_acc.pth"))


  0%|                                                                          | 0/2 [00:00<?, ?it/s]
  0%|                                                                          | 0/4 [00:00<?, ?it/s][A

  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:14,  6.65it/s][A[A

Iteration : 0 
 , Train Loss : 6.860051155090332 
, Test Acc : 5.2920637130737305 





  3%|█▉                                                              | 3/100 [00:00<00:13,  7.25it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:15,  5.98it/s][A[A

  7%|████▍                                                           | 7/100 [00:01<00:12,  7.17it/s][A[A

  9%|█████▊                                                          | 9/100 [00:01<00:12,  7.14it/s][A[A

 11%|██████▉                                                        | 11/100 [00:01<00:13,  6.50it/s][A[A

 13%|████████▏                                                      | 13/100 [00:01<00:12,  6.94it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:02<00:10,  7.81it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:02<00:09,  8.47it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:02<00:10,  7.48it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 0.8610170364379883 
, Test Acc : 62.780426025390625 





 53%|█████████████████████████████████▍                             | 53/100 [00:06<00:05,  8.99it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:06<00:04,  9.47it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:07<00:04,  9.77it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:07<00:04,  9.10it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:07<00:04,  9.35it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:07<00:04,  7.94it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:08<00:05,  6.48it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:08<00:05,  5.98it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:08<00:04,  6.74it/s][A[A

 71%|████████████

Iteration : 99 
 , Train Loss : 5.9776934623718265 
, Test Acc : 58.84391403198242 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:16,  5.94it/s][A[A

Iteration : 0 
 , Train Loss : 8.302072683970133 
, Test Acc : 11.180607795715332 





  3%|█▉                                                              | 3/100 [00:00<00:11,  8.11it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:10,  9.08it/s][A[A

  7%|████▍                                                           | 7/100 [00:00<00:10,  9.28it/s][A[A

  9%|█████▊                                                          | 9/100 [00:01<00:10,  8.77it/s][A[A

 11%|██████▉                                                        | 11/100 [00:01<00:09,  9.09it/s][A[A

 13%|████████▏                                                      | 13/100 [00:01<00:10,  8.22it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:01<00:11,  7.60it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:02<00:10,  8.15it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:02<00:11,  7.12it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 0.9069405868649483 
, Test Acc : 70.22774505615234 





 53%|█████████████████████████████████▍                             | 53/100 [00:08<00:07,  6.64it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:08<00:06,  7.32it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:08<00:05,  7.51it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:08<00:05,  7.26it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:09<00:05,  7.20it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:09<00:05,  6.32it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:09<00:05,  6.77it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:09<00:04,  7.67it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:10<00:03,  8.33it/s][A[A

 71%|████████████

Iteration : 99 
 , Train Loss : 4.893404783986625 
, Test Acc : 66.44702911376953 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<00:12,  7.66it/s][A[A



Iteration : 0 
 , Train Loss : 7.267216682434082 
, Test Acc : 13.115544319152832 



  3%|█▉                                                              | 3/100 [00:00<00:10,  9.28it/s][A[A

  5%|███▏                                                            | 5/100 [00:00<00:09,  9.74it/s][A[A

  7%|████▍                                                           | 7/100 [00:00<00:09,  9.56it/s][A[A

  9%|█████▊                                                          | 9/100 [00:00<00:09,  9.52it/s][A[A

 11%|██████▉                                                        | 11/100 [00:01<00:09,  9.72it/s][A[A

 13%|████████▏                                                      | 13/100 [00:01<00:13,  6.56it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:02<00:22,  3.72it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:05<00:47,  1.75it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:10<01:40,  1.24s/it][A[A

 21%|█████████████▏

Iteration : 50 
 , Train Loss : 0.5129024535417557 
, Test Acc : 68.73974609375 





 53%|█████████████████████████████████▍                             | 53/100 [00:19<00:06,  6.98it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:19<00:05,  7.50it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:19<00:06,  6.66it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:20<00:06,  6.19it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:21<00:09,  4.02it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:22<00:14,  2.63it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:23<00:17,  2.04it/s][A[A

 66%|█████████████████████████████████████████▌                     | 66/100 [00:23<00:14,  2.37it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:25<00:19,  1.68it/s][A[A

 69%|████████████

Iteration : 99 
 , Train Loss : 0.5947947651147842 
, Test Acc : 72.68962860107422 





  0%|                                                                        | 0/100 [00:00<?, ?it/s][A[A

  1%|▋                                                               | 1/100 [00:00<01:21,  1.21it/s][A[A

Iteration : 0 
 , Train Loss : 5.938006401062012 
, Test Acc : 10.43360424041748 





  3%|█▉                                                              | 3/100 [00:02<01:12,  1.33it/s][A[A

  5%|███▏                                                            | 5/100 [00:03<01:13,  1.29it/s][A[A

  7%|████▍                                                           | 7/100 [00:05<01:21,  1.14it/s][A[A

  9%|█████▊                                                          | 9/100 [00:06<00:55,  1.63it/s][A[A

 11%|██████▉                                                        | 11/100 [00:06<00:37,  2.35it/s][A[A

 13%|████████▏                                                      | 13/100 [00:06<00:29,  2.96it/s][A[A

 15%|█████████▍                                                     | 15/100 [00:06<00:22,  3.83it/s][A[A

 17%|██████████▋                                                    | 17/100 [00:07<00:22,  3.67it/s][A[A

 19%|███████████▉                                                   | 19/100 [00:08<00:32,  2.49it/s][A[A

 21%|████████████

Iteration : 50 
 , Train Loss : 0.2583047188818455 
, Test Acc : 80.35230255126953 





 53%|█████████████████████████████████▍                             | 53/100 [00:35<01:03,  1.34s/it][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:37<00:54,  1.20s/it][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:38<00:42,  1.01it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:39<00:34,  1.19it/s][A[A

 61%|██████████████████████████████████████▍                        | 61/100 [00:40<00:26,  1.46it/s][A[A

 63%|███████████████████████████████████████▋                       | 63/100 [00:40<00:21,  1.74it/s][A[A

 65%|████████████████████████████████████████▉                      | 65/100 [00:42<00:20,  1.69it/s][A[A

 67%|██████████████████████████████████████████▏                    | 67/100 [00:43<00:19,  1.68it/s][A[A

 69%|███████████████████████████████████████████▍                   | 69/100 [00:44<00:20,  1.49it/s][A[A

 71%|████████████

Iteration : 99 
 , Train Loss : 0.14426066353917122 
, Test Acc : 80.35230255126953 




  0%|                                                                         | 0/32 [00:00<?, ?it/s][A
  3%|██                                                               | 1/32 [00:00<00:13,  2.26it/s][A
  6%|████                                                             | 2/32 [00:00<00:12,  2.32it/s][A
  9%|██████                                                           | 3/32 [00:01<00:10,  2.68it/s][A
 12%|████████▏                                                        | 4/32 [00:01<00:09,  2.90it/s][A
 16%|██████████▏                                                      | 5/32 [00:01<00:08,  3.08it/s][A
 19%|████████████▏                                                    | 6/32 [00:01<00:07,  3.42it/s][A
 22%|██████████████▏                                                  | 7/32 [00:02<00:07,  3.50it/s][A
 25%|████████████████▎                                                | 8/32 [00:02<00:06,  3.86it/s][A
 28%|██████████████████▎                              

Iteration : 0 
 , Train Loss : 7.023851171135902 
, Test Acc : 7.08223295211792 





  2%|█▎                                                              | 2/100 [00:00<00:23,  4.20it/s][A[A

  3%|█▉                                                              | 3/100 [00:00<00:24,  4.01it/s][A[A

  4%|██▌                                                             | 4/100 [00:00<00:22,  4.35it/s][A[A

  5%|███▏                                                            | 5/100 [00:01<00:22,  4.14it/s][A[A

  6%|███▊                                                            | 6/100 [00:01<00:21,  4.44it/s][A[A

  7%|████▍                                                           | 7/100 [00:01<00:22,  4.17it/s][A[A

  8%|█████                                                           | 8/100 [00:01<00:20,  4.41it/s][A[A

  9%|█████▊                                                          | 9/100 [00:02<00:21,  4.25it/s][A[A

 10%|██████▎                                                        | 10/100 [00:02<00:20,  4.48it/s][A[A

 11%|██████▉     

Iteration : 50 
 , Train Loss : 1.036673075548606 
, Test Acc : 68.19358825683594 





 52%|████████████████████████████████▊                              | 52/100 [00:19<00:19,  2.47it/s][A[A

 53%|█████████████████████████████████▍                             | 53/100 [00:20<00:19,  2.37it/s][A[A

 54%|██████████████████████████████████                             | 54/100 [00:20<00:18,  2.55it/s][A[A

 55%|██████████████████████████████████▋                            | 55/100 [00:21<00:18,  2.38it/s][A[A

 56%|███████████████████████████████████▎                           | 56/100 [00:21<00:17,  2.57it/s][A[A

 57%|███████████████████████████████████▉                           | 57/100 [00:22<00:17,  2.44it/s][A[A

 58%|████████████████████████████████████▌                          | 58/100 [00:22<00:16,  2.57it/s][A[A

 59%|█████████████████████████████████████▏                         | 59/100 [00:22<00:16,  2.43it/s][A[A

 60%|█████████████████████████████████████▊                         | 60/100 [00:23<00:15,  2.62it/s][A[A

 61%|████████████

Iteration : 99 
 , Train Loss : 0.713878551061498 
, Test Acc : 72.79914093017578 




  0%|                                                                         | 0/32 [00:00<?, ?it/s][A
 16%|██████████▏                                                      | 5/32 [00:00<00:02, 10.22it/s][A
 22%|██████████████▏                                                  | 7/32 [00:00<00:03,  8.22it/s][A
 25%|████████████████▎                                                | 8/32 [00:01<00:04,  5.59it/s][A
 28%|██████████████████▎                                              | 9/32 [00:01<00:05,  3.99it/s][A
 31%|████████████████████                                            | 10/32 [00:02<00:06,  3.46it/s][A
 34%|██████████████████████                                          | 11/32 [00:02<00:06,  3.43it/s][A
 38%|████████████████████████                                        | 12/32 [00:02<00:06,  3.14it/s][A
 41%|██████████████████████████                                      | 13/32 [00:03<00:06,  2.74it/s][A
 44%|████████████████████████████                     

{0: [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31]}

In [34]:
# Show the current value of test_acc; the refinement loop cell leaves the mean
# of the per-cluster accuracies of its last round in it.
print(test_acc)

tensor(72.7991)


In [19]:
class GlobalTrainer(BaseTrainer):
    """Trains a single shared model on all clients.

    Every iteration draws one batch from each client, computes that client's
    gradient on the current model, and applies one optimizer step using the
    coordinate-wise mean of the collected gradients.
    """

    def __init__(self,  config, save_dir):
        super(GlobalTrainer, self).__init__(config, save_dir)

    def train(self, client_data_list):
        """Run ``config["iterations"]`` rounds of gradient-averaged training.

        After every round the model is evaluated on all clients with
        ``calc_acc``. Weights and metrics are saved every
        ``config["save_freq"]`` rounds and on the last round; progress is
        printed every ``config["print_freq"]`` rounds and on the last round.
        The model is left in eval mode on the CPU when training ends.

        Args:
            client_data_list: per-client data objects exposing
                ``sample_batch()`` and accepted by ``calc_acc``.
        """
        num_clients = len(client_data_list)

        train_loss_list = []
        test_acc_list = []
        self.model.to(self.device)
        self.model.train()

        optimizer = OPTIMIZER_LIST[self.config["optimizer"]](self.model.parameters(), **self.config["optimizer_params"])
        # The loss module is stateless, so one instance serves every batch.
        loss_func = nn.CrossEntropyLoss()

        for iteration in tqdm(range(self.config["iterations"])):
            # One list of per-client gradients for each trainable parameter.
            trmean_buffer = {
                name: []
                for name, param in self.model.named_parameters()
                if param.requires_grad
            }
            train_loss = 0

            for client in client_data_list:
                optimizer.zero_grad(set_to_none=True)
                (X, Y) = client.sample_batch()
                # Use the trainer's own device (the one the model was moved to)
                # instead of reaching for the notebook-level config.
                X = X.to(self.device)
                Y = Y.to(self.device)
                out = self.model(X)
                loss = loss_func(out, Y)
                loss.backward()
                train_loss += loss.item()

                with torch.no_grad():
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            trmean_buffer[name].append(param.grad.clone())
            train_loss = train_loss/num_clients
            optimizer.zero_grad()

            # The slice spans every client, so no gradient is trimmed and the
            # result is the plain coordinate-wise mean.
            start_idx = 0
            end_idx = num_clients

            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    sorted_grads, _ = torch.sort(torch.stack(trmean_buffer[name], dim=0), dim=0)
                    param.grad = sorted_grads[start_idx:end_idx, ...].mean(dim=0)
            optimizer.step()

            train_loss_list.append(train_loss)
            test_acc = 0
            for client_data in client_data_list:
                test_acc += calc_acc(self.model, self.device, client_data, train=False)
            test_acc = test_acc/num_clients
            test_acc_list.append(test_acc)
            # Back to train mode after the evaluation pass above.
            self.model.train()
            if iteration % self.config["save_freq"] == 0 or iteration == self.config["iterations"] - 1:
                self.save_model_weights()
                self.save_metrics(train_loss_list, test_acc_list, iteration)
            if iteration % self.config["print_freq"] == 0 or iteration == self.config["iterations"] - 1:
                print("Iteration : {} \n , Train Loss : {} \n, Test Acc : {} \n".format(iteration,  train_loss, test_acc))

        self.model.eval()
        self.model.cpu()

    def test(self, client_data_list):
        """Return the mean test accuracy of the model over the given clients.

        Calls ``self.load_model_weights()`` before evaluating and moves the
        model back to the CPU afterwards. The summed accuracy is divided by
        the number of clients, the same averaging ``train`` uses when it
        reports accuracy.

        Args:
            client_data_list: per-client data objects accepted by ``calc_acc``.

        Returns:
            Mean of ``calc_acc(..., train=False)`` over ``client_data_list``.
        """
        self.load_model_weights()
        self.model.eval()
        self.model.to(self.device)
        test_acc = 0
        for client_data in client_data_list:
            test_acc += calc_acc(self.model, self.device, client_data, train=False)
        test_acc = test_acc/len(client_data_list)
        self.model.cpu()
        return test_acc

# Train the single global model on every client; the trainer is given
# <results_dir>/global as its save directory.
global_save_dir = os.path.join(config["results_dir"], "global")
global_trainer = GlobalTrainer(config, global_save_dir)
global_trainer.train(client_loaders)


  1%|▋                                                               | 1/100 [00:00<00:32,  3.01it/s]

Iteration : 0 
 , Train Loss : 6.9710216373205185 
, Test Acc : 8.685996055603027 



 52%|████████████████████████████████▊                              | 52/100 [00:16<00:11,  4.34it/s]

Iteration : 50 
 , Train Loss : 0.8661993985515437 
, Test Acc : 69.55622100830078 



100%|██████████████████████████████████████████████████████████████| 100/100 [00:37<00:00,  2.64it/s]

Iteration : 99 
 , Train Loss : 0.5586700489147916 
, Test Acc : 72.58793640136719 






In [20]:
import gc

# Drop the notebook's reference to the global trainer, run a garbage
# collection pass, then release PyTorch's cached CUDA memory.
del global_trainer
gc.collect()
torch.cuda.empty_cache()