In [60]:
import os
import copy
import torch
import numpy as np
import config as cfg
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from models.simpleNet import Net
from datamanager import *
from Fed_Algorithms import FedAvg
import sklearn.metrics as metrics
from collections import OrderedDict
from multiprocessing import pool, cpu_count
from multiprocessing import Process
from torch.optim import SGD

class Client():
    def __init__(self, client_id:str, model:nn.Module, data_info:dict=None, device:str=cfg.DEVICE):
        self.id = client_id
        #self.cfg = cfg
        
        self.__model = None
        self.device = device
        
        self.train_info, self.test_info = data_info['train'], data_info['test'] # 함수화하기
        self.trainset, self.testset = FEMNIST(self.train_info), FEMNIST(self.test_info)
        
    @property
    def model(self):             
        return self.__model

    @model.setter
    def model(self, model):
        self.__model = model
    
    def __len__(self):
        return len(self.trainset)
    
    def setup(self):
        self.train_loader = DataLoader(self.trainset, batch_size=16, shuffle=True)
        self.test_loader = DataLoader(self.testset, batch_size=16, shuffle=False)
        self.optimizer = SGD(self.model.parameters(), lr=0.01)     # TODO: utils.get_optimizer(cfg['optim']:str)
        self.criterion = nn.CrossEntropyLoss()                     # TODO: utils.get_loss(cfg['loss']:str)
        self.epochs = 10
    
    def local_train(self)->None:
        self.model.train()
        self.model.to(self.device)
        # TRAINING
        for epoch in range(self.epochs):
            for idx, batch in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss.backward()
                self.optimizer.step()
                if "cuda" in self.device : torch.cuda.empty_cache()
        # TESTING
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.train_loader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
                if "cuda" in self.device : torch.cuda.empty_cache()
            train_acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            train_loss = np.average(loss_trace)
            self.model.to('cpu')
        
        print(f'=== Client {self.id} Finished Training {len(self)} samples ===')
        print(f'client:{self.id} | Train Acc:{train_acc*100:.2f} | Train Loss:{train_loss:.4f}')
    
    def local_test(self):
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.test_loader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
                
                if "cuda" in self.device : torch.cuda.empty_cache()
                
            test_acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            test_loss = np.average(loss_trace)
            print(f'client:{self.id} | Test Acc:{test_acc*100:.2f} | Test Loss:{test_loss:.4f}')
            self.model.to('cpu')

class Server():
    def __init__(self, DM_dict:dict, algorithm:str=None):
        self.train_DM = DM_dict['train']
        self.test_DM = DM_dict['test']
        
        self.clients = None
        self.device = cfg.DEVICE
        
        self.global_model = Net()
        
        self.criterion = nn.CrossEntropyLoss()              # TODO: utils.get_loss(cfg['loss']:str)
        
        self.Algorithm = FedAvg.FedAvg                      # FedAVG 같은 aggrrgation method 들어감 TODO: utils.get_algortihm() 작성
        
        self.mp_flag = False

    def setup(self):
        self.clients = self.create_clients()
        self.data = FEMNIST(self.test_DM.global_test_data)
        self.dataloader = DataLoader(self.data, batch_size=256, shuffle=False)
        self.transmit_model()
        self.setup_clients()
        
        
    def create_clients(self, n_users:int=35):
        self.user_ids = self.test_DM.users
        self.user_ids = np.random.choice(self.user_ids, n_users, replace=False)
        clients = {}
        for user in self.user_ids:
            data_info = {'train':self.train_DM.data[user],\
                         'test':self.test_DM.data[user]}
            clients[user] = Client(client_id=user, model=self.global_model, data_info=data_info, device=self.device)
            clients[user].model = copy.deepcopy(self.global_model)
        return clients
    
    def setup_clients(self)->None:
        for k, client in tqdm(enumerate(self.clients), leave=False):
            self.clients[client].setup()
    
    def transmit_model(self, sampled_clients:list=None)->None:
        if sampled_clients == None:
            for client in tqdm(self.clients, leave=False):
                self.clients[client].model.load_state_dict(copy.deepcopy(self.global_model.state_dict()))
                # self.clients[client].model = copy.deepcopy(self.global_model)
        else:
            for client in tqdm(sampled_clients, leave=False):
                self.clients[client].model.load_state_dict(copy.deepcopy(self.global_model.state_dict()))
                # self.clients[client].model = copy.deepcopy(self.global_model)

        
    def sample_clients(self, n_participant:int=10)->np.array:
        assert n_participant <= len(self.user_ids), "Check 'n_participant <= len(self.clients)'"
        return np.random.choice(self.user_ids, n_participant, replace=False) # 입력된 수의 유저를 추출해서 반환

    def train_selected_clients(self, sampled_clients:list)->None:
        total_sample = 0
        for client in tqdm(sampled_clients, leave=False):
            self.clients[client].local_train()
            total_sample += len(self.clients[client])
        return total_sample
    
    def mp_train_selected_clients(self, procnum:int, client:str)->None:
        self.clients[client].local_train()
    
    def test_selected_models(self, sampled_clients):
        for client in sampled_clients:
            self.clients[client].local_test()

    def mp_test_selected_models(self, procnum:int, client:str):
        self.clients[client].local_test()
    
    def average_model(self, sampled_clients, coefficients):
        averaged_weights = OrderedDict()
        for it, client in tqdm(enumerate(sampled_clients), leave=False):
            local_weights = self.clients[client].model.state_dict()
            for key in self.global_model.state_dict().keys():
                if it == 0:
                    averaged_weights[key] = coefficients[it] * local_weights[key]
                else:
                    averaged_weights[key] += coefficients[it] * local_weights[key]
        self.global_model.load_state_dict(averaged_weights)
    
    def update_model(self, train_result:dict, layers:list=None):
        self.received_models, num_samples = [], []
        for result in train_result:
            self.received_models.append(result['model'])
            num_samples.append(result['num_sample'])
        state = self.Algorithm(self.received_models, num_samples, layers)
        self.global_model.load_state_dict(state)

    def train_federated_model(self):
        sampled_clients = self.sample_clients()
        print(f"CLIENTS {sampled_clients} ARE SELECTED!\n")
        
        if self.mp_flag:
            print("TRAIN WITH MP!\n")
            procs = []
            selected_total_size = []
            for idx, c in enumerate(sampled_clients):
                selected_total_size.append(len(self.clients[c]))
                proc = Process(target=self.mp_train_selected_clients, args=(idx, c))
                proc.start()
                procs.append(proc)
            for proc in procs:
                proc.join()
            selected_total_size = sum(selected_total_size)
            # with pool.ThreadPool(processes=cpu_count() - 1) as workhorse:
            #     selected_total_size = workhorse.map(self.mp_train_selected_clients, sampled_clients)
            
        else:
            print("TRAIN WITH SP!\n")
            selected_total_size = self.train_selected_clients(sampled_clients)

        # if not self.mp_flag:
        #     print("TEST WITH MP!\n")
        #     procs = []
        #     for idx, c in enumerate(sampled_clients):
        #         proc = Process(target=self.mp_test_selected_models, args=(idx, c))
        #         proc.start()
        #         procs.append(proc)
        #     for proc in procs:
        #         proc.join()
        #     # with pool.ThreadPool(processes=cpu_count() - 1) as workhorse:
        #     #     workhorse.map(self.mp_test_selected_models, sampled_clients)
        # else:
        print("TEST WITH SP!\n")
        self.test_selected_models(sampled_clients)
                
        mixing_coefficients = [len(self.clients[client]) / selected_total_size for client in sampled_clients]
        
        # print(f"mixing_coefficients:{mixing_coefficients}")
        # print(f'CLIENT {sampled_clients[0]} WEIGHT')
        # print(f'CLIENT {self.clients[sampled_clients[0]].model.state_dict()["fc1.weight"]} WEIGHT')
        # print(f'CLIENT {sampled_clients[1]} WEIGHT')
        # print(f'CLIENT {self.clients[sampled_clients[1]].model.state_dict()["fc1.weight"]} WEIGHT')
        # print(f'Server Weight before update')
        # print(f'CLIENT {self.global_model.state_dict()["fc1.weight"]} WEIGHT')
        self.average_model(sampled_clients, mixing_coefficients)
        # print(f'Server Weight after update')
        # print(f'CLIENT {self.global_model.state_dict()["fc1.weight"]} WEIGHT')
        self.transmit_model()
        
    def global_test(self):
        self.global_model.eval()
        self.global_model.to(self.device)
        
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.dataloader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.global_model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
            self.acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            self.test_loss = np.average(loss_trace)
            print(f'Global Test Result | Acc:{self.acc*100:.2f}, Loss:{self.test_loss:.4f}')
            self.global_model.to('cpu')
    def fit(self):
        pass

In [4]:
import torch.multiprocessing as mp
from multiprocessing import Process

import torch.autograd as autograd

autograd.set_detect_anomaly(True)

mp.set_start_method('spawn')
PATH = cfg.DATAPATH['femnist']
    
file_dict = get_files(PATH)
    
TRAIN_DM = DataManager(file_dict['train'], is_train=True)
TEST_DM = DataManager(file_dict['test'], is_train=False)
DM_dict = {'train':TRAIN_DM,
           'test':TEST_DM}
print("DATA READY")
    
print(f"WORKING WITH {cfg.DEVICE}")
    

DATA READY
WORKING WITH cuda


In [61]:
server = Server(DM_dict)
print("SERVER READY")
    
server.setup()
print('===== ROUND 0 =====\nServer Setup Complete!')

SERVER READY


                                      

===== ROUND 0 =====
Server Setup Complete!




In [62]:
for i in range(10):
    print(f'===== ROUND {i+1} START! =====\n')
    server.train_federated_model()
    server.global_test()
        

  0%|          | 0/10 [00:00<?, ?it/s]

===== ROUND 1 START! =====

CLIENTS ['17' '25' '32' '16' '4' '7' '31' '1' '9' '26'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:13<02:02, 13.62s/it]

=== Client 17 Finished Training 3008 samples ===
client:17 | Train Acc:22.87 | Train Loss:3.2651


 20%|██        | 2/10 [00:42<02:59, 22.45s/it]

=== Client 25 Finished Training 6286 samples ===
client:25 | Train Acc:18.12 | Train Loss:3.1497


 30%|███       | 3/10 [00:55<02:09, 18.46s/it]

=== Client 32 Finished Training 3022 samples ===
client:32 | Train Acc:23.00 | Train Loss:3.2049


 40%|████      | 4/10 [01:09<01:38, 16.47s/it]

=== Client 16 Finished Training 2965 samples ===
client:16 | Train Acc:14.33 | Train Loss:3.2624


 50%|█████     | 5/10 [01:32<01:33, 18.70s/it]

=== Client 4 Finished Training 4955 samples ===
client:4 | Train Acc:10.19 | Train Loss:3.5545


 60%|██████    | 6/10 [01:58<01:25, 21.42s/it]

=== Client 7 Finished Training 5900 samples ===
client:7 | Train Acc:7.61 | Train Loss:3.5613


 70%|███████   | 7/10 [02:10<00:55, 18.41s/it]

=== Client 31 Finished Training 2677 samples ===
client:31 | Train Acc:18.12 | Train Loss:3.3102


 80%|████████  | 8/10 [02:24<00:33, 16.96s/it]

=== Client 1 Finished Training 3071 samples ===
client:1 | Train Acc:14.62 | Train Loss:3.2901


 90%|█████████ | 9/10 [02:51<00:20, 20.16s/it]

=== Client 9 Finished Training 5974 samples ===
client:9 | Train Acc:13.34 | Train Loss:3.4158


                                               

=== Client 26 Finished Training 6253 samples ===
client:26 | Train Acc:17.77 | Train Loss:3.2414
TEST WITH SP!

client:17 | Test Acc:22.67 | Test Loss:3.4334
client:25 | Test Acc:18.48 | Test Loss:3.1034
client:32 | Test Acc:23.70 | Test Loss:3.1894
client:16 | Test Acc:12.39 | Test Loss:3.2665
client:4 | Test Acc:9.82 | Test Loss:3.5355
client:7 | Test Acc:8.56 | Test Loss:3.5411
client:31 | Test Acc:14.01 | Test Loss:3.3133
client:1 | Test Acc:11.40 | Test Loss:3.2704


                                      

client:9 | Test Acc:12.00 | Test Loss:3.3579
client:26 | Test Acc:15.06 | Test Loss:3.2951


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:16.92, Loss:3.5864
===== ROUND 2 START! =====

CLIENTS ['24' '1' '15' '30' '10' '26' '11' '27' '28' '34'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:28<04:16, 28.55s/it]

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:25.24 | Train Loss:2.7434


 20%|██        | 2/10 [00:42<02:39, 19.93s/it]

=== Client 1 Finished Training 3071 samples ===
client:1 | Train Acc:16.80 | Train Loss:2.9020


 30%|███       | 3/10 [00:56<01:59, 17.04s/it]

=== Client 15 Finished Training 2987 samples ===
client:15 | Train Acc:28.59 | Train Loss:2.6623


 40%|████      | 4/10 [01:08<01:31, 15.20s/it]

=== Client 30 Finished Training 2722 samples ===
client:30 | Train Acc:28.36 | Train Loss:2.8941


 50%|█████     | 5/10 [01:35<01:37, 19.45s/it]

=== Client 10 Finished Training 5915 samples ===
client:10 | Train Acc:35.28 | Train Loss:2.5510


 60%|██████    | 6/10 [02:00<01:25, 21.49s/it]

=== Client 26 Finished Training 6253 samples ===
client:26 | Train Acc:33.78 | Train Loss:2.4986


 70%|███████   | 7/10 [02:12<00:55, 18.39s/it]

=== Client 11 Finished Training 2639 samples ===
client:11 | Train Acc:39.79 | Train Loss:2.5253


 80%|████████  | 8/10 [02:41<00:43, 21.65s/it]

=== Client 27 Finished Training 6294 samples ===
client:27 | Train Acc:37.38 | Train Loss:2.3487


 90%|█████████ | 9/10 [03:09<00:23, 23.61s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:13.43 | Train Loss:3.3082


                                               

=== Client 34 Finished Training 3002 samples ===
client:34 | Train Acc:33.31 | Train Loss:2.7297
TEST WITH SP!

client:24 | Test Acc:25.39 | Test Loss:2.8300
client:1 | Test Acc:17.95 | Test Loss:2.8534
client:15 | Test Acc:29.03 | Test Loss:2.5926
client:30 | Test Acc:29.81 | Test Loss:2.8809
client:10 | Test Acc:36.38 | Test Loss:2.4991
client:26 | Test Acc:33.10 | Test Loss:2.5741
client:11 | Test Acc:41.25 | Test Loss:2.6166


                                      

client:27 | Test Acc:39.63 | Test Loss:2.2591
client:28 | Test Acc:13.64 | Test Loss:3.3385
client:34 | Test Acc:32.46 | Test Loss:2.8708


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:33.03, Loss:2.6080
===== ROUND 3 START! =====

CLIENTS ['17' '21' '30' '2' '26' '6' '4' '15' '1' '24'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:13<02:02, 13.66s/it]

=== Client 17 Finished Training 3008 samples ===
client:17 | Train Acc:51.36 | Train Loss:1.8820


 20%|██        | 2/10 [00:27<01:51, 13.91s/it]

=== Client 21 Finished Training 3099 samples ===
client:21 | Train Acc:51.44 | Train Loss:1.8551


 30%|███       | 3/10 [00:40<01:32, 13.25s/it]

=== Client 30 Finished Training 2722 samples ===
client:30 | Train Acc:8.67 | Train Loss:4.9199


 40%|████      | 4/10 [01:02<01:41, 16.85s/it]

=== Client 2 Finished Training 4941 samples ===
client:2 | Train Acc:38.47 | Train Loss:2.2113


 50%|█████     | 5/10 [01:30<01:44, 20.99s/it]

=== Client 26 Finished Training 6253 samples ===
client:26 | Train Acc:59.19 | Train Loss:1.5225


 60%|██████    | 6/10 [01:53<01:26, 21.54s/it]

=== Client 6 Finished Training 4982 samples ===
client:6 | Train Acc:58.17 | Train Loss:1.6188


 70%|███████   | 7/10 [02:16<01:05, 21.90s/it]

=== Client 4 Finished Training 4955 samples ===
client:4 | Train Acc:41.47 | Train Loss:2.1089


 80%|████████  | 8/10 [02:29<00:38, 19.27s/it]

=== Client 15 Finished Training 2987 samples ===
client:15 | Train Acc:53.33 | Train Loss:1.8182


 90%|█████████ | 9/10 [02:43<00:17, 17.59s/it]

=== Client 1 Finished Training 3071 samples ===
client:1 | Train Acc:52.46 | Train Loss:1.8878


                                               

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:54.18 | Train Loss:1.7040
TEST WITH SP!

client:17 | Test Acc:45.64 | Test Loss:2.0600
client:21 | Test Acc:56.62 | Test Loss:1.7875
client:30 | Test Acc:7.69 | Test Loss:4.9651
client:2 | Test Acc:40.14 | Test Loss:2.2371
client:26 | Test Acc:57.95 | Test Loss:1.5885
client:6 | Test Acc:62.39 | Test Loss:1.5591
client:4 | Test Acc:41.61 | Test Loss:2.0569


                                      

client:15 | Test Acc:51.32 | Test Loss:1.8576
client:1 | Test Acc:51.85 | Test Loss:1.8807
client:24 | Test Acc:54.44 | Test Loss:1.8187


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:47.74, Loss:1.9760
===== ROUND 4 START! =====

CLIENTS ['0' '32' '5' '7' '13' '9' '14' '28' '6' '10'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:12<01:53, 12.66s/it]

=== Client 0 Finished Training 2754 samples ===
client:0 | Train Acc:13.11 | Train Loss:4.3255


 20%|██        | 2/10 [00:26<01:46, 13.33s/it]

=== Client 32 Finished Training 3022 samples ===
client:32 | Train Acc:62.84 | Train Loss:1.3215


 30%|███       | 3/10 [00:48<01:59, 17.10s/it]

=== Client 5 Finished Training 4751 samples ===
client:5 | Train Acc:59.67 | Train Loss:1.4722


 40%|████      | 4/10 [01:14<02:05, 20.92s/it]

=== Client 7 Finished Training 5900 samples ===
client:7 | Train Acc:59.95 | Train Loss:1.4621


 50%|█████     | 5/10 [01:26<01:28, 17.74s/it]

=== Client 13 Finished Training 3010 samples ===
client:13 | Train Acc:44.39 | Train Loss:2.0180


 60%|██████    | 6/10 [01:51<01:20, 20.08s/it]

=== Client 9 Finished Training 5974 samples ===
client:9 | Train Acc:62.89 | Train Loss:1.3500


 70%|███████   | 7/10 [02:04<00:53, 17.86s/it]

=== Client 14 Finished Training 2946 samples ===
client:14 | Train Acc:25.08 | Train Loss:3.4783


 80%|████████  | 8/10 [02:32<00:41, 20.99s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:29.49 | Train Loss:2.7702


 90%|█████████ | 9/10 [02:55<00:21, 21.52s/it]

=== Client 6 Finished Training 4982 samples ===
client:6 | Train Acc:55.48 | Train Loss:1.5217


                                               

=== Client 10 Finished Training 5915 samples ===
client:10 | Train Acc:60.51 | Train Loss:1.4703
TEST WITH SP!

client:0 | Test Acc:16.77 | Test Loss:4.1923
client:32 | Test Acc:64.16 | Test Loss:1.3131
client:5 | Test Acc:57.54 | Test Loss:1.6193
client:7 | Test Acc:57.51 | Test Loss:1.5976
client:13 | Test Acc:44.61 | Test Loss:2.0927
client:9 | Test Acc:64.44 | Test Loss:1.3479
client:14 | Test Acc:27.73 | Test Loss:3.5898


                                      

client:28 | Test Acc:27.43 | Test Loss:2.8468
client:6 | Test Acc:60.96 | Test Loss:1.4697
client:10 | Test Acc:63.92 | Test Loss:1.4333


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:54.79, Loss:1.7063
===== ROUND 5 START! =====

CLIENTS ['20' '35' '10' '5' '28' '31' '24' '1' '2' '4'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:14<02:06, 14.01s/it]

=== Client 20 Finished Training 3102 samples ===
client:20 | Train Acc:69.47 | Train Loss:1.1170


 20%|██        | 2/10 [00:27<01:50, 13.83s/it]

=== Client 35 Finished Training 3030 samples ===
client:35 | Train Acc:60.40 | Train Loss:1.4076


 30%|███       | 3/10 [00:54<02:18, 19.76s/it]

=== Client 10 Finished Training 5915 samples ===
client:10 | Train Acc:61.35 | Train Loss:1.3495


 40%|████      | 4/10 [01:16<02:02, 20.49s/it]

=== Client 5 Finished Training 4751 samples ===
client:5 | Train Acc:60.09 | Train Loss:1.3976


 50%|█████     | 5/10 [01:43<01:55, 23.12s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:44.11 | Train Loss:2.0458


 60%|██████    | 6/10 [01:56<01:17, 19.42s/it]

=== Client 31 Finished Training 2677 samples ===
client:31 | Train Acc:53.34 | Train Loss:1.8019


 70%|███████   | 7/10 [02:24<01:07, 22.42s/it]

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:64.28 | Train Loss:1.2556


 80%|████████  | 8/10 [02:38<00:39, 19.73s/it]

=== Client 1 Finished Training 3071 samples ===
client:1 | Train Acc:63.79 | Train Loss:1.3182


 90%|█████████ | 9/10 [03:01<00:20, 20.53s/it]

=== Client 2 Finished Training 4941 samples ===
client:2 | Train Acc:58.79 | Train Loss:1.4277


                                               

=== Client 4 Finished Training 4955 samples ===
client:4 | Train Acc:66.84 | Train Loss:1.1850
TEST WITH SP!

client:20 | Test Acc:58.87 | Test Loss:1.6466
client:35 | Test Acc:58.84 | Test Loss:1.6052
client:10 | Test Acc:61.23 | Test Loss:1.3883
client:5 | Test Acc:59.96 | Test Loss:1.5394
client:28 | Test Acc:46.44 | Test Loss:1.9652
client:31 | Test Acc:50.81 | Test Loss:2.1387
client:24 | Test Acc:61.21 | Test Loss:1.4043


                                      

client:1 | Test Acc:62.96 | Test Loss:1.4598
client:2 | Test Acc:61.47 | Test Loss:1.4139
client:4 | Test Acc:66.96 | Test Loss:1.2063


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:57.98, Loss:1.5063
===== ROUND 6 START! =====

CLIENTS ['13' '28' '34' '15' '24' '35' '20' '7' '0' '31'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:13<02:03, 13.74s/it]

=== Client 13 Finished Training 3010 samples ===
client:13 | Train Acc:68.31 | Train Loss:1.1201


 20%|██        | 2/10 [00:41<02:55, 21.96s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:52.15 | Train Loss:1.6300


 30%|███       | 3/10 [00:55<02:06, 18.13s/it]

=== Client 34 Finished Training 3002 samples ===
client:34 | Train Acc:76.88 | Train Loss:0.8118


 40%|████      | 4/10 [01:08<01:38, 16.34s/it]

=== Client 15 Finished Training 2987 samples ===
client:15 | Train Acc:63.31 | Train Loss:1.3013


 50%|█████     | 5/10 [01:37<01:43, 20.78s/it]

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:66.34 | Train Loss:1.1424


 60%|██████    | 6/10 [01:51<01:13, 18.40s/it]

=== Client 35 Finished Training 3030 samples ===
client:35 | Train Acc:68.55 | Train Loss:1.1048


 70%|███████   | 7/10 [02:05<00:50, 16.97s/it]

=== Client 20 Finished Training 3102 samples ===
client:20 | Train Acc:72.05 | Train Loss:0.9982


 80%|████████  | 8/10 [02:31<00:40, 20.13s/it]

=== Client 7 Finished Training 5900 samples ===
client:7 | Train Acc:67.12 | Train Loss:1.1416


 90%|█████████ | 9/10 [02:44<00:17, 17.74s/it]

=== Client 0 Finished Training 2754 samples ===
client:0 | Train Acc:72.26 | Train Loss:1.0151


                                               

=== Client 31 Finished Training 2677 samples ===
client:31 | Train Acc:72.51 | Train Loss:1.0594
TEST WITH SP!

client:13 | Test Acc:66.18 | Test Loss:1.2605
client:28 | Test Acc:49.78 | Test Loss:1.7803
client:34 | Test Acc:71.30 | Test Loss:1.0838
client:15 | Test Acc:58.65 | Test Loss:1.6086
client:24 | Test Acc:65.16 | Test Loss:1.3138
client:35 | Test Acc:65.22 | Test Loss:1.3395
client:20 | Test Acc:61.69 | Test Loss:1.6015


                                      

client:7 | Test Acc:62.61 | Test Loss:1.3296
client:0 | Test Acc:68.67 | Test Loss:1.2181
client:31 | Test Acc:68.08 | Test Loss:1.3785


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:60.85, Loss:1.3720
===== ROUND 7 START! =====

CLIENTS ['34' '24' '18' '6' '14' '12' '11' '35' '16' '33'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:13<02:03, 13.70s/it]

=== Client 34 Finished Training 3002 samples ===
client:34 | Train Acc:78.58 | Train Loss:0.7417


 20%|██        | 2/10 [00:42<03:00, 22.52s/it]

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:62.72 | Train Loss:1.2631


 30%|███       | 3/10 [00:56<02:11, 18.74s/it]

=== Client 18 Finished Training 3098 samples ===
client:18 | Train Acc:74.82 | Train Loss:0.8806


 40%|████      | 4/10 [01:19<02:01, 20.33s/it]

=== Client 6 Finished Training 4982 samples ===
client:6 | Train Acc:62.85 | Train Loss:1.2481


 50%|█████     | 5/10 [01:32<01:29, 17.85s/it]

=== Client 14 Finished Training 2946 samples ===
client:14 | Train Acc:51.56 | Train Loss:1.6610


 60%|██████    | 6/10 [01:58<01:22, 20.67s/it]

=== Client 12 Finished Training 5760 samples ===
client:12 | Train Acc:70.78 | Train Loss:1.0042


 70%|███████   | 7/10 [02:10<00:53, 17.83s/it]

=== Client 11 Finished Training 2639 samples ===
client:11 | Train Acc:73.74 | Train Loss:0.9414


 80%|████████  | 8/10 [02:24<00:33, 16.57s/it]

=== Client 35 Finished Training 3030 samples ===
client:35 | Train Acc:65.12 | Train Loss:1.2289


 90%|█████████ | 9/10 [02:38<00:15, 15.63s/it]

=== Client 16 Finished Training 2965 samples ===
client:16 | Train Acc:69.98 | Train Loss:1.0317


                                               

=== Client 33 Finished Training 3037 samples ===
client:33 | Train Acc:73.39 | Train Loss:0.9280
TEST WITH SP!

client:34 | Test Acc:71.01 | Test Loss:1.0699
client:24 | Test Acc:60.93 | Test Loss:1.4519
client:18 | Test Acc:77.40 | Test Loss:0.9413
client:6 | Test Acc:64.71 | Test Loss:1.3203
client:14 | Test Acc:52.80 | Test Loss:1.8738
client:12 | Test Acc:69.44 | Test Loss:1.0750
client:11 | Test Acc:71.95 | Test Loss:1.1368
client:35 | Test Acc:60.87 | Test Loss:1.5108


                                      

client:16 | Test Acc:69.03 | Test Loss:1.1130
client:33 | Test Acc:72.13 | Test Loss:1.0547


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:63.43, Loss:1.2857
===== ROUND 8 START! =====

CLIENTS ['19' '8' '2' '11' '13' '24' '17' '1' '28' '23'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:14<02:07, 14.15s/it]

=== Client 19 Finished Training 3110 samples ===
client:19 | Train Acc:71.77 | Train Loss:0.9477


 20%|██        | 2/10 [00:40<02:49, 21.16s/it]

=== Client 8 Finished Training 5711 samples ===
client:8 | Train Acc:59.17 | Train Loss:1.3654


 30%|███       | 3/10 [01:01<02:27, 21.13s/it]

=== Client 2 Finished Training 4941 samples ===
client:2 | Train Acc:77.51 | Train Loss:0.7832


 40%|████      | 4/10 [01:13<01:44, 17.47s/it]

=== Client 11 Finished Training 2639 samples ===
client:11 | Train Acc:74.38 | Train Loss:0.8958


 50%|█████     | 5/10 [01:26<01:20, 16.10s/it]

=== Client 13 Finished Training 3010 samples ===
client:13 | Train Acc:46.11 | Train Loss:2.1123


 60%|██████    | 6/10 [01:55<01:21, 20.32s/it]

=== Client 24 Finished Training 6299 samples ===
client:24 | Train Acc:57.41 | Train Loss:1.4392


 70%|███████   | 7/10 [02:09<00:54, 18.15s/it]

=== Client 17 Finished Training 3008 samples ===
client:17 | Train Acc:69.75 | Train Loss:1.0268


 80%|████████  | 8/10 [02:23<00:33, 16.84s/it]

=== Client 1 Finished Training 3071 samples ===
client:1 | Train Acc:75.28 | Train Loss:0.8570


 90%|█████████ | 9/10 [02:50<00:20, 20.26s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:68.13 | Train Loss:1.0643


                                               

=== Client 23 Finished Training 3216 samples ===
client:23 | Train Acc:75.40 | Train Loss:0.8414
TEST WITH SP!

client:19 | Test Acc:65.17 | Test Loss:1.1704
client:8 | Test Acc:58.36 | Test Loss:1.5119
client:2 | Test Acc:74.37 | Test Loss:0.8976
client:11 | Test Acc:68.65 | Test Loss:1.1138
client:13 | Test Acc:49.27 | Test Loss:1.9840
client:24 | Test Acc:54.58 | Test Loss:1.6731
client:17 | Test Acc:65.12 | Test Loss:1.2581


                                      

client:1 | Test Acc:71.51 | Test Loss:1.0575
client:28 | Test Acc:65.75 | Test Loss:1.1651
client:23 | Test Acc:70.30 | Test Loss:1.0900


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:64.18, Loss:1.2511
===== ROUND 9 START! =====

CLIENTS ['4' '28' '9' '2' '35' '31' '27' '12' '20' '11'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:22<03:23, 22.62s/it]

=== Client 4 Finished Training 4955 samples ===
client:4 | Train Acc:73.82 | Train Loss:0.8717


 20%|██        | 2/10 [00:50<03:24, 25.59s/it]

=== Client 28 Finished Training 6100 samples ===
client:28 | Train Acc:70.41 | Train Loss:0.9833


 30%|███       | 3/10 [01:17<03:04, 26.33s/it]

=== Client 9 Finished Training 5974 samples ===
client:9 | Train Acc:72.31 | Train Loss:0.9238


 40%|████      | 4/10 [01:39<02:28, 24.76s/it]

=== Client 2 Finished Training 4941 samples ===
client:2 | Train Acc:80.15 | Train Loss:0.6957


 50%|█████     | 5/10 [01:53<01:43, 20.80s/it]

=== Client 35 Finished Training 3030 samples ===
client:35 | Train Acc:72.84 | Train Loss:0.8990


 60%|██████    | 6/10 [02:03<01:08, 17.16s/it]

=== Client 31 Finished Training 2677 samples ===
client:31 | Train Acc:81.73 | Train Loss:0.6500


 70%|███████   | 7/10 [02:32<01:02, 20.90s/it]

=== Client 27 Finished Training 6294 samples ===
client:27 | Train Acc:70.54 | Train Loss:0.9745


 80%|████████  | 8/10 [02:58<00:45, 22.55s/it]

=== Client 12 Finished Training 5760 samples ===
client:12 | Train Acc:74.24 | Train Loss:0.8539


 90%|█████████ | 9/10 [03:12<00:19, 19.87s/it]

=== Client 20 Finished Training 3102 samples ===
client:20 | Train Acc:74.05 | Train Loss:0.8603


                                               

=== Client 11 Finished Training 2639 samples ===
client:11 | Train Acc:77.79 | Train Loss:0.7575
TEST WITH SP!

client:4 | Test Acc:73.21 | Test Loss:0.9330
client:28 | Test Acc:69.96 | Test Loss:1.0739
client:9 | Test Acc:72.00 | Test Loss:0.9720
client:2 | Test Acc:76.88 | Test Loss:0.8055
client:35 | Test Acc:66.09 | Test Loss:1.2309
client:31 | Test Acc:75.90 | Test Loss:0.9789


                                      

client:27 | Test Acc:72.21 | Test Loss:1.0398
client:12 | Test Acc:71.60 | Test Loss:0.9595
client:20 | Test Acc:64.23 | Test Loss:1.4727
client:11 | Test Acc:72.28 | Test Loss:0.9858


  0%|          | 0/10 [00:00<?, ?it/s]

Global Test Result | Acc:66.56, Loss:1.1702
===== ROUND 10 START! =====

CLIENTS ['13' '32' '34' '8' '31' '10' '29' '17' '33' '5'] ARE SELECTED!

TRAIN WITH SP!



 10%|█         | 1/10 [00:13<02:03, 13.71s/it]

=== Client 13 Finished Training 3010 samples ===
client:13 | Train Acc:28.77 | Train Loss:4.1838


 20%|██        | 2/10 [00:27<01:49, 13.70s/it]

=== Client 32 Finished Training 3022 samples ===
client:32 | Train Acc:82.13 | Train Loss:0.6281


 30%|███       | 3/10 [00:41<01:35, 13.67s/it]

=== Client 34 Finished Training 3002 samples ===
client:34 | Train Acc:80.28 | Train Loss:0.6359


 40%|████      | 4/10 [01:06<01:51, 18.50s/it]

=== Client 8 Finished Training 5711 samples ===
client:8 | Train Acc:74.44 | Train Loss:0.8486


 50%|█████     | 5/10 [01:19<01:21, 16.21s/it]

=== Client 31 Finished Training 2677 samples ===
client:31 | Train Acc:80.99 | Train Loss:0.6651


 60%|██████    | 6/10 [01:45<01:19, 19.84s/it]

=== Client 10 Finished Training 5915 samples ===
client:10 | Train Acc:71.87 | Train Loss:0.9386


 70%|███████   | 7/10 [01:58<00:52, 17.35s/it]

=== Client 29 Finished Training 2676 samples ===
client:29 | Train Acc:81.50 | Train Loss:0.6290


 80%|████████  | 8/10 [02:11<00:32, 16.17s/it]

=== Client 17 Finished Training 3008 samples ===
client:17 | Train Acc:71.51 | Train Loss:0.9594


 90%|█████████ | 9/10 [02:25<00:15, 15.43s/it]

=== Client 33 Finished Training 3037 samples ===
client:33 | Train Acc:77.38 | Train Loss:0.7628


                                               

=== Client 5 Finished Training 4751 samples ===
client:5 | Train Acc:78.91 | Train Loss:0.7091
TEST WITH SP!

client:13 | Test Acc:26.53 | Test Loss:4.3846
client:32 | Test Acc:77.46 | Test Loss:0.8324
client:34 | Test Acc:72.46 | Test Loss:0.9960
client:8 | Test Acc:72.76 | Test Loss:1.0034
client:31 | Test Acc:76.55 | Test Loss:0.9238
client:10 | Test Acc:70.66 | Test Loss:1.0772
client:29 | Test Acc:73.87 | Test Loss:0.9328
client:17 | Test Acc:67.44 | Test Loss:1.2062


                                      

client:33 | Test Acc:75.00 | Test Loss:0.9275
client:5 | Test Acc:75.79 | Test Loss:0.9208




Global Test Result | Acc:67.20, Loss:1.1167


In [49]:
server.global_model.state_dict()['fc1.weight']

tensor([[-0.0308,  0.0339,  0.0247,  ..., -0.0260,  0.0069,  0.0245],
        [ 0.0048, -0.0336,  0.0300,  ...,  0.0076, -0.0006,  0.0174],
        [ 0.0171, -0.0087, -0.0146,  ...,  0.0111,  0.0015, -0.0210],
        ...,
        [ 0.0073,  0.0097, -0.0242,  ..., -0.0315, -0.0354,  0.0192],
        [-0.0113, -0.0231,  0.0196,  ..., -0.0013, -0.0307,  0.0232],
        [ 0.0082, -0.0187, -0.0180,  ..., -0.0172, -0.0092, -0.0107]])

In [50]:
server.global_test()

Global Test Result | Acc:40.58, Loss:2.4326


In [52]:
server.clients['27'].local_test()

client:27 | Test Acc:49.37 | Test Loss:1.9029


In [53]:
server.clients['27'].local_train()

=== Client 27 Finished Training 6294 samples ===
client:27 | Train Acc:62.42 | Train Loss:1.3118


In [54]:
server.clients['27'].local_test()

client:27 | Test Acc:60.93 | Test Loss:1.3687


In [55]:
server.global_model.state_dict()['fc1.weight']

tensor([[-0.0308,  0.0339,  0.0247,  ..., -0.0260,  0.0069,  0.0245],
        [ 0.0048, -0.0336,  0.0300,  ...,  0.0076, -0.0006,  0.0174],
        [ 0.0171, -0.0087, -0.0146,  ...,  0.0111,  0.0015, -0.0210],
        ...,
        [ 0.0073,  0.0097, -0.0242,  ..., -0.0315, -0.0354,  0.0192],
        [-0.0113, -0.0231,  0.0196,  ..., -0.0013, -0.0307,  0.0232],
        [ 0.0082, -0.0187, -0.0180,  ..., -0.0172, -0.0092, -0.0107]])

In [56]:
server.global_test()

Global Test Result | Acc:40.58, Loss:2.4326


In [59]:
server.clients['27'].optimizer.param_groups[0]['lr']

0.01

# CLIENT

In [2]:
from torch.optim import SGD

class Client():
    def __init__(self, client_id:str, model:nn.Module, data_info:dict=None, device:str=cfg.DEVICE):
        self.id = client_id
        #self.cfg = cfg
        
        self.__model = None
        self.device = device
        
        self.train_info, self.test_info = data_info['train'], data_info['test'] # 함수화하기
        self.trainset, self.testset = FEMNIST(self.train_info), FEMNIST(self.test_info)
        
    @property
    def model(self):             
        return self.__model

    @model.setter
    def model(self, model):
        self.__model = model
    
    def __len__(self):
        return len(self.trainset)
    
    def setup(self):
        self.train_loader = DataLoader(self.trainset, batch_size=16, shuffle=True)
        self.test_loader = DataLoader(self.testset, batch_size=16, shuffle=False)
        self.optimizer = SGD(self.model.parameters(), lr=0.01)     # TODO: utils.get_optimizer(cfg['optim']:str)
        self.criterion = nn.CrossEntropyLoss()                       # TODO: utils.get_loss(cfg['loss']:str)
        self.epochs = 10
    
    def local_train(self)->None:
        proc = os.getpid()
        self.model.train()
        self.model.to(self.device)
        # TRAINING
        for epoch in range(self.epochs):
            for idx, batch in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss.backward()
                self.optimizer.step()
                if "cuda" in self.device : torch.cuda.empty_cache()
        # TESTING
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.train_loader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
                if "cuda" in self.device : torch.cuda.empty_cache()
            train_acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            train_loss = np.average(loss_trace)
            self.model.to('cpu')
        
        print(f'=== Process ID: {proc} | Client {self.id} Finished Training {len(self)} samples ===')
        print(f'client:{self.id} | Train Acc:{train_acc*100:.2f} | Train Loss:{train_loss:.4f}')
    
    def local_test(self):
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.test_loader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
                
                if "cuda" in self.device : torch.cuda.empty_cache()
                
            test_acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            test_loss = np.average(loss_trace)
            print(f'client:{self.id} | Test Acc:{test_acc*100:.2f} | Test Loss:{test_loss:.4f}')
            self.model.to('cpu')

# SERVER

In [3]:
class Server():
    def __init__(self, DM_dict:dict, algorithm:str=None):
        self.train_DM = DM_dict['train']
        self.test_DM = DM_dict['test']
        
        self.clients = None
        self.device = cfg.DEVICE
        
        self.global_model = Net()
        
        self.criterion = nn.CrossEntropyLoss()              # TODO: utils.get_loss(cfg['loss']:str)
        
        self.Algorithm = FedAvg.FedAvg                      # FedAVG 같은 aggrrgation method 들어감 TODO: utils.get_algortihm() 작성
        self.received_models = None                         # Client.upload_model() 결과가 여기 들어감
        
        self.mp_flag = True

    def setup(self):
        self.clients = self.create_clients()
        self.data = FEMNIST(self.test_DM.get_global_testset())
        self.dataloader = DataLoader(self.data, batch_size=256, shuffle=False)
        
        self.transmit_model()
        self.setup_clients()
        
        
    def create_clients(self, n_users:int=100):
        self.user_ids = self.test_DM.users
        self.user_ids = np.random.choice(self.user_ids, n_users, replace=False)
        clients = {}
        for user in self.user_ids:
            data_info = {'train':self.train_DM.get_user_info(user),\
                         'test':self.test_DM.get_user_info(user)}
            clients[user] = Client(client_id=user, model=self.global_model, data_info=data_info)
        return clients
    
    def setup_clients(self)->None:
        for k, client in tqdm(enumerate(self.clients), leave=False):
            self.clients[client].setup()
    
    def transmit_model(self, sampled_clients:list=None)->None:
        if sampled_clients == None:
            for client in tqdm(self.clients, leave=False):
                self.clients[client].model = copy.deepcopy(self.global_model)
        else:
            for client in tqdm(sampled_clients, leave=False):
                self.clients[client].model = copy.deepcopy(self.global_model)

        
    def sample_clients(self, n_participant:int=50)->np.array:
        assert n_participant <= len(self.user_ids), "Check 'n_participant <= len(self.clients)'"
        return np.random.choice(self.user_ids, n_participant, replace=False) # 입력된 수의 유저를 추출해서 반환

    def train_selected_clients(self, sampled_clients:list)->None:
        total_sample = 0
        for client in tqdm(sampled_clients, leave=False):
            self.clients[client].local_train()
            total_sample += len(self.clients[client])

    def mp_train_selected_clients(self, client:str)->None:
        self.clients[client].local_train()
        n_sample = len(self.clients[client])
        return n_sample
    
    def test_selected_models(self, sampled_clients):
        for client in sampled_clients:
            self.clients[client].local_test()

    def mp_test_selected_models(self, client):
        self.clients[client].local_test()
    
    def average_model(self, sampled_clients, coefficients):
        averaged_weights = OrderedDict()
        for it, client in tqdm(enumerate(sampled_clients), leave=False):
            local_weights = self.clients[client].model.state_dict()
            for key in self.global_model.state_dict().keys():
                if it == 0:
                    averaged_weights[key] = coefficients[it] * local_weights[key]
                else:
                    averaged_weights[key] += coefficients[it] * local_weights[key]
        self.global_model.load_state_dict(averaged_weights)
    
    def update_model(self, train_result:dict, layers:list=None):
        self.received_models, num_samples = [], []
        for result in train_result:
            self.received_models.append(result['model'])
            num_samples.append(result['num_sample'])
        state = self.Algorithm(self.received_models, num_samples, layers)
        self.global_model.load_state_dict(state)

    def train_federated_model(self):
        sampled_clients = self.sample_clients()
        
        if self.mp_flag:
            with pool.ThreadPool(processes=cpu_count() - 1) as workhorse:
                selected_total_size = workhorse.map(self.mp_train_selected_clients, sampled_clients)
            selected_total_size = sum(selected_total_size)
        else:
            selected_total_size = self.train_selected_clients(sampled_clients)

        if self.mp_flag:
            with pool.ThreadPool(processes=cpu_count() - 1) as workhorse:
                workhorse.map(self.mp_test_selected_models, sampled_clients)
        else:
            self.test_selected_models(sampled_clients)
        
        mixing_coefficients = [len(self.clients[client]) / selected_total_size for client in sampled_clients]
        
        self.average_model(sampled_clients, mixing_coefficients)
        
    def global_test(self):
        self.global_model.eval()
        self.global_model.to(self.device)
        
        with torch.no_grad():
            loss_trace, result_pred, result_anno = [], [], []
            for idx, batch in enumerate(self.dataloader):
                X, Y = batch
                X, Y = X.to(self.device), Y.to(self.device)
                pred = self.global_model(X)
                loss = self.criterion(pred, Y)
                loss_trace.append(loss.to('cpu').detach().numpy())
                pred_np  = pred.to('cpu').detach().numpy()
                pred_np  = np.argmax(pred_np, axis=1).squeeze()
                Y_np     = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
                result_pred = np.hstack((result_pred, pred_np))
                result_anno = np.hstack((result_anno, Y_np))
            self.acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
            self.test_loss = np.average(loss_trace)
            print(f'Global Test Result | Acc:{self.acc*100:.2f}, Loss:{self.test_loss:.4f}')
            self.global_model.to(self.device)

In [4]:
PATH = cfg.DATAPATH['femnist']
    
file_dict = get_files(PATH)
    
TRAIN_DM = DataManager(file_dict['train'], is_train=True)
TEST_DM = DataManager(file_dict['test'], is_train=False)
DM_dict = {'train':TRAIN_DM,
            'test':TEST_DM}
print("DATA READY")

DATA READY


In [None]:
server = Server(DM_dict)
#server.create_clients()
server.setup()
for i in range(100):
    server.train_federated_model()
    server.global_test()


In [None]:
file = server.train_DM.files[0]
with open(file) as f:
    data = json.load(f)
data


In [41]:
np.shape(data['user_data'][data['users'][0]]['x'])

(141, 784)

In [67]:
class DataManager():
    def __init__(self, files:list, is_train:bool=True):
        self.files = files
        self.is_train = is_train
        self.users, self.data = [], {}
        if not self.is_train:
            self.global_test_data = {'x':[], 'y':[]}
        
        for idx, file in enumerate(self.files):
            idx = str(idx)
            self.data[idx] = {'x':[], 'y':[]}
            with open(file) as f:
                data = json.load(f)
                self.users.append(idx)

                for user in data['users']:              # 각 유저의 data 저장
                    self.data[idx]['x'] = self.data[idx]['x'] + data['user_data'][user]['x']
                    self.data[idx]['y'] = self.data[idx]['y'] + data['user_data'][user]['y']
                    
                if not self.is_train:                   # test dataset인 경우 global evaluation 위해서 모든 데이터셋 저장
                    for user in data['users']:
                        self.global_test_data['x'] = self.global_test_data['x'] + data['user_data'][user]['x']
                        self.global_test_data['y'] = self.global_test_data['y'] + data['user_data'][user]['y']

class FEMNIST(Dataset):
    def __init__(self, data:dict):
        self.data = data
        self.data['x'] = np.array(data['x'])
        self.data['y'] = np.array(data['y'])
        
    def __getitem__(self, idx):
        #self.X = torch.tensor(self.data['x'][idx,:].reshape(1, 28, 28)).float()
        self.X = torch.tensor(self.data['x'][idx,:]).float()
        self.Y = torch.tensor(self.data['y'][idx]).long()
        return self.X, self.Y

    def __len__(self):
        return len(self.data['y'])

In [53]:
PATH = '../leaf/data/femnist/data/train'
files = [os.path.join(PATH, file) for file in os.listdir(PATH) if file.endswith('.json')]
files.sort()

DM = DataManager(files, is_train=False)

In [49]:
np.shape(DM.data['1']['x'])

(351, 784)

In [66]:
np.shape(DM.data['0']['x'])

(2754, 784)

In [68]:
dataset = FEMNIST(DM.data['0'])

In [69]:
dataset[0]

(tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0