======FairFed============

-> Init Models for clients

-> Get statistics of the dataset

-> start training

-> for each client

        > Update its weights based on local and global fairness
        
        > store it
-> SecAgg applied on the server side

In [1]:
import pandas as pd
import time
import numpy as np
import copy
import glob
import sys, os
import torch
import torch.nn as nn
import torch.utils.data as data_utils
from sklearn.metrics import precision_score, recall_score
import pickle
from sklearn import preprocessing
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from folktables import ACSDataSource, ACSEmployment,ACSIncome
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_curve
from sklearn.model_selection import StratifiedKFold, train_test_split
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
import os
from scipy.stats import multivariate_normal
import torch, random, copy, os
from collections import OrderedDict
import shutil

In [2]:
# TPR
# Gender: 
# 1: Male, 0: Female
# Fairness: Male- Female
# eq-4: Global- TPR

# data preparation

In [3]:
import torch

if torch.cuda.is_available():
    print("CUDA is available.")
else:
    # CUDA is not available
    print("CUDA is not available. Running on CPU.")


CUDA is available.


In [4]:
 
class DeepNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 14 : input shape
        # 9-> we have 9 columns in data 
        self.layer1 = nn.Linear(14, 512)
        self.act1 = nn.ReLU()
        self.dropout1 = nn.Dropout(p=0.5)
        self.layer2 = nn.Linear(512, 256)
        self.act2 = nn.ReLU()
        self.layer3 = nn.Linear(256, 60)
        self.act3 = nn.ReLU()
        self.output = nn.Linear(60, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.act1(self.layer1(x))
        x = self.dropout1(x)
        x = self.act2(self.layer2(x))
        x = self.act3(self.layer3(x))
        x = self.sigmoid(self.output(x))
        return x



# assigning Models

In [5]:
class AssignModel:
    def __init__(self, global_model_base_path, selected_clients,algo="FF"):      
        self.global_model_path = f"{global_model_base_path}/global_model_{algo}.pt"
        self.selected_clients = selected_clients

    def save_global_model(self, model):
        torch.save(model.state_dict(), self.global_model_path)

    def save_client_models(self, model):
        for client_id in self.selected_clients:
            client_model_path = f"/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/client_{client_id}_model.pth"
            torch.save(model.state_dict(), client_model_path)

In [6]:
class Client:
    def __init__(self):
        self.client_id: int = None
        self.valset: DataLoader = None
        self.trainset: DataLoader = None
        self.testset: DataLoader = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")         
        self.model = DeepNet().to(self.device)
        self.criterion = torch.nn.BCELoss()

        self.fainess_score=0

        self.client_df = pd.DataFrame()
        self.selected_clients=[0,1,2,3]

        self.client_fairness=0
        self.global_fairness=0

        
    def get_client_local_dataset(self):
                
        temp_path_data="/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/data_fairFed"
        
        with open(temp_path_data+"/clients_training.pkl", "rb") as f:
            self.trainset = pickle.load(f)
            
        with open(temp_path_data+"/clients_validation.pkl", "rb") as f:
            self.valset = pickle.load(f)
         
        with open(temp_path_data+"/clients_testing_wrong.pkl", "rb") as f:
            self.testset  = pickle.load(f)

        self.trainset = self.trainset[self.client_id]
        self.valset = self.valset[self.client_id]       
        self.testset = self.testset[self.client_id]

    def get_stats(self):
        # 1 male 0 female
        total_pr_Y1_A0 = 0
        total_pr_Y1_A1 = 0

        for client_id in self.selected_clients:
            self.client_id=client_id
            self.get_client_local_dataset()
            for inputs, labels, A in self.trainset:
                pr_Y1_A1 = torch.sum((labels == 1) & (A == 1)).item()
                pr_Y1_A0 = torch.sum((labels == 1) & (A == 0)).item()
          
                total_pr_Y1_A0 += pr_Y1_A0
                total_pr_Y1_A1 += pr_Y1_A1
        temp_sum=total_pr_Y1_A0+total_pr_Y1_A1
        return [total_pr_Y1_A1/temp_sum,total_pr_Y1_A0/temp_sum]
        
    def calculate_fairness(self,y_hat, A, Y,server_acc,client_acc,len_client_data):
        # Calculate counts using torch.sum
        y_hat, A, Y = y_hat.to(self.device), A.to(self.device), Y.to(self.device)

        # 1: Male, 0: Female
        predict_Y_male = torch.sum((y_hat[(A == 1) & (Y == 1)] == 1)).item()
        predict_Y_female = torch.sum((y_hat[(A == 0) & (Y == 1)] == 1)).item()
        
        count_Y_male = torch.sum((Y[(A == 1) & (Y == 1)] == 1)).item()
        count_Y_female = torch.sum((Y[(A == 0) & (Y == 1)] == 1)).item()
    
        # Calculate probabilities
        
        if count_Y_male > 0 and count_Y_female > 0:
            prob_Y_male = predict_Y_male / count_Y_male
            prob_Y_female = predict_Y_female / count_Y_female
            fairness_score = prob_Y_male - prob_Y_female
        else:
            # Cal sever_acc from client using Eq.
            total_len_data=sum(len_client_data)   
            temp_server_acc=0
            for i in range(len(self.selected_clients)):
                temp_server_acc+=client_acc[i]*(len_client_data[i]/total_len_data)
                
                
            avg_acc_client=sum(client_acc) /len(client_acc)
            fairness_score= abs(avg_acc_client-server_acc)
            # print("Else calculate_fairness fairness_score",fairness_score)
            
        return  fairness_score

    def calculate_ser_fairness(self,y_hat, A, Y,statistics,server_acc,client_acc,len_client_data):
        # print("calculate_ser_fairness statistics",statistics)
        y_hat, A, Y = y_hat.to(self.device), A.to(self.device), Y.to(self.device)

        # for individual clients!
        #Male 
        predict_Y_male = torch.sum((y_hat[(A == 1) & (Y == 1)] == 1)).item() #get total
        #Pr(A=0, Y=1)
        count_Y_male = torch.sum((A == 1) & (Y == 1)).item() 
        #Pr(Ŷ=1|A=1, Y=1)  #Female      
        predict_Y_female = torch.sum((y_hat[(A == 0) & (Y == 1)] == 1)).item()  #get total
        #Pr(A=1, Y=1)
        count_Y_female = torch.sum((A == 0) & (Y == 1)).item()
        
        if count_Y_male>0 and count_Y_female>0:
            predict_Y_male=predict_Y_male/count_Y_male #Pr(Ŷ=1|A=0, Y=1)
            predict_Y_female=predict_Y_female/count_Y_female   #Pr(Ŷ=1|A=1, Y=1)
            temp_fairness_server=((predict_Y_male * count_Y_male)/statistics[0]) - ((predict_Y_female * count_Y_female)/statistics[1])
        else:
            # Cal sever_acc from client using Eq.
            total_len_data=sum(len_client_data)   
            temp_server_acc=0
            for i in range(len(self.selected_clients)):
                temp_server_acc+=client_acc[i]*(len_client_data[i]/total_len_data)
                
                
            avg_acc_client=sum(client_acc) /len(client_acc)
            temp_fairness_server= abs(avg_acc_client-server_acc)
            # print("Else calculate_ser_fairness",temp_fairness_server)

        # print("calculate_ser_fairness temp_fairness_server",self.client_id,temp_fairness_server)
        return temp_fairness_server


    def train(self, client_id: int, model_path,statistics,server_acc,clients_acc,len_client_data, num_epochs=5, learning_rate=0.001):
        self.client_id = client_id
        self.get_client_local_dataset()

        # Define loss function and optimizer
        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.fainess_score=0
      
        self.model.load_state_dict(torch.load(model_path,map_location=self.device))
        for epoch in range(num_epochs):
            # Set the model to training mode
            self.model.train()

            for inputs, labels, sens in self.trainset:
                
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs).to(self.device)
                loss = self.criterion(outputs, labels.float())
                loss.backward()
                optimizer.step()

            # Calculate and print the training accuracy for this epoch (optional)
            
            correct = 0
            total = 0
            size = 0
            loss = 0
            correct = 0
            total = 0
            predicted_labels = []
            true_labels = []
            final_fairness=[]
            final_fairness_server=[]
            
            with torch.no_grad():
                print("======Validation========")
                for inputs, labels,sens in self.valset:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    
                    outputs = self.model(inputs)
                    loss += self.criterion(outputs, labels)
                    predicted = outputs > 0.5

                    predicted_labels.extend(predicted.cpu().numpy())
                    true_labels.extend(labels.cpu().numpy())                   
                    
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()                   

            loss = loss / len(self.valset)        
            acc = correct/total
 
            fairNess_per_batch=self.calculate_fairness(torch.round(outputs).squeeze(), sens.squeeze(), labels.squeeze(),
                                                       server_acc,
                                                       clients_acc,
                                                      len_client_data)
            final_fairness.append(fairNess_per_batch)
            self.fainess_score=np.mean(final_fairness)

            fairNess_per_batch_server=self.calculate_ser_fairness(torch.round(outputs).squeeze(), sens.squeeze(), labels.squeeze(),
                                                                  statistics,
                                                                  server_acc,
                                                                  clients_acc,
                                                                 len_client_data)
            final_fairness_server.append(fairNess_per_batch_server)            
            ser_score=np.mean(final_fairness_server)
                        
            precision = precision_score(true_labels, predicted_labels,zero_division=0.0)            
            recall = recall_score(true_labels, predicted_labels)
            
            # accuracy = 100 * correct / total
            print(f"Epoch {epoch+1}/{num_epochs}, Val Accuracy: {acc:.5f}%")
            print(f"Epoch {epoch+1}/{num_epochs}, Val Fairness: {self.fainess_score:.5f}%")
            print(f"Epoch {epoch+1}/{num_epochs}, Val Recall: {recall:.5f}%")


        # Optionally, save the trained model parameters
        #store model
        torch.save(self.model.state_dict(), model_path)

        # Return the trained model parameters
        return list(self.model.state_dict().values()), len(self.trainset.dataset),self.model.state_dict(), self.fainess_score, ser_score

    @torch.no_grad()
    def client_evaluate(self, val=False):
        
        size = 0
        loss = 0
        correct = 0
        total = 0
        predicted_labels = []
        true_labels = []
        len_of_data=[]
        all_client_acc=[]
        
        models_directory = "/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/"
                 
        for client_id in self.selected_clients:
            
            model_path = os.path.join(models_directory, f"client_{client_id}_model.pth")            
        
            self.model.load_state_dict(torch.load(model_path,map_location=self.device))

            self.model.eval()

            for inputs, targets,sens in self.testset:
                
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                outputs = self.model(inputs)
                loss += self.criterion(outputs, targets)
                predicted = outputs > 0.5
                
                predicted_labels.extend(predicted.cpu().numpy())
                true_labels.extend(targets.cpu().numpy())
                
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                
                len_of_data.append(len(self.testset))
            
            loss = loss / len(self.testset)
            acc = correct / total
            all_client_acc.append(acc)
                        
            precision = precision_score(true_labels, predicted_labels,zero_division=0.0)
            
            recall = recall_score(true_labels, predicted_labels)
                            
        return loss, acc, precision, recall,len_of_data, all_client_acc

    
    @torch.no_grad()
    def server_evaluate(self,path,all_client_acc,len_of_data):

        # print("Global Model testing Starts")
        # print("kindly check the Path. Select it based on FedAvg Model")
        # print("warning: Test data has already been used")
        temp_path_data="/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/"

        with open(temp_path_data+"/testing_client.pkl", "rb") as f:
            testset  = pickle.load(f)
            
        
        # path="/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/global_model_2.pt"
        
        self.model.load_state_dict(torch.load(path,map_location=self.device))
        self.model.eval()
        size = 0
        loss = 0
        correct = 0
        total = 0
        predicted_labels = []
        true_labels = []
        final_fairness=[]

        self.testset=testset[1]
        for inputs, targets,sens in self.testset:
            
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            
            outputs = self.model(inputs)
            loss += self.criterion(outputs, targets)
            predicted = outputs > 0.5

            
#            print("predicted",predicted)
            predicted_labels.extend(predicted.cpu().numpy())
            true_labels.extend(targets.cpu().numpy())
            
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            
        
        loss = loss / len(self.testset)
        ser_acc = correct / total

        fairNess_per_batch=self.calculate_fairness(torch.round(outputs).squeeze(), sens.squeeze(), targets.squeeze(),
                                                    ser_acc,
                                                    all_client_acc,
                                                    len_of_data)
        final_fairness.append(fairNess_per_batch)
        fairness_global=np.mean(final_fairness)

        # print("loss: %f\n" % (loss))
        
        # print(f"Global Testing Accuracy: {ser_acc:.5%}")
        # print(f" Global Fairness: {fairness_global:.5f}%")
        precision = precision_score(true_labels, predicted_labels,zero_division=0.0)
        
        recall = recall_score(true_labels, predicted_labels)
        
        # print(f"Global Precision: {precision:.5%}")
        # print(f"Global Recall: {recall:.5%}")
        
        return loss, ser_acc, precision, recall,fairness_global
                                            

In [16]:
class Serverbase:
    def __init__ (self,model):
        self.model = model
        self.global_model=model
        self.num_rounds=2
        self.local_epoch=2
        self.optimizer=2
        self.lr=0.001
        self.beta=1
        self.batch_size=32
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

    
        self.updated_params_cache = []
        self.weights_cache = []
        self.model_dict = []
        
        self.global_params_dict: OrderedDict[str : torch.Tensor] = None

        self.backbone=DeepNet
        _dummy_model = self.backbone()
        self.global_params_dict = OrderedDict(_dummy_model.state_dict())

        self.temp_dir="/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/"

        self.fair_global_t_step=0
        self.fair_local_client=0   
        
        self.acc_clients=[]
        self.clients_id=[]
        self.precision_clients=[]
        self.recall_clients=[]
        
        self.global_acc=0
        self.fairness_diff=[]
        self.clients_weights=[]
        self.final_agg_weights_clients=[]
        self.selected_clients=[0,1,2,3]
        self.statistics=0

    def global_fairness(self,weights, all_clients_fairness): 
        # all_clients_fairness is based on equation: 7
        # upper part is done in client class
        # self.statistics cal. pr(y=1,a=0)
        
        # print("global_fairness::",all_clients_fairness)
        # print("weights::",weights)
        
        weight_sum = sum(weights)  
        final_global_score=0
        
        for i in range(len(all_clients_fairness)):
            temp=weights[i]/weight_sum            
            final_global_score+=temp*all_clients_fairness[i]
            
        self.fair_global=final_global_score

    
        
    def fairFed(self, num_rounds = 2, local_epochs = 2, learning_rate = 0.001, beta = 1, optimizer = 'adam',algo="FF"):
        
        client = Client()
        path=f"/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/global_model_{algo}.pt"
        models_directory = "/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models/"
        
        #before we start the training, we need statistics
        print("Getting data statistics....")
        self.statistics= client.get_stats()      
        

        for round_ in tqdm(range(num_rounds)):
  
            #each round we need empty lists
            self.updated_params_cache = []
            self.weights_cache = []
            self.model_dict=[]
            self.clients_weights=[]
            self.final_agg_weights_clients=[]
            self.fair_global=0
            temp_ser_fairness=[]

            print ("round::",round_)
            
             # clients and server acc for calculating Fairness whenever fairness_client is undefined
        
            loss_client, acc_client,precision, recall,len_of_data,all_client_acc = client.client_evaluate() 
            
            loss_client, accuracy_server, pre_server, recall_ser,fairness_global = client.server_evaluate(path,all_client_acc,len_of_data)

            
            for client_id in self.selected_clients:
                self.global_acc=0
                
                print("client_id:: ",client_id)
                # weight is length of dataset

                
                #take a model of particular client
                model_path = os.path.join(models_directory, f"client_{client_id}_model.pth")                       
                
                updated_params_list, weight, model_dict_list,fairness_client,ser_fairness = client.train(client_id, 
                                                                                                         model_path, 
                                                                                                         statistics=self.statistics,
                                                                                                         server_acc=accuracy_server,
                                                                                                         clients_acc=all_client_acc,
                                                                                                         len_client_data=len_of_data,
                                                                                                         num_epochs=local_epochs, 
                                                                                                         learning_rate=learning_rate)        
                            
                # store it for SecAgg
                self.updated_params_cache.append(updated_params_list)
                self.weights_cache.append(weight)
                self.model_dict.append(model_dict_list)
                temp_ser_fairness.append(ser_fairness)
                
                # print("temp_ser_fairness",temp_ser_fairness)
                #store it for FairFed agg
                
                # self.global_acc=acc_server
                self.fair_local_client=fairness_client                
            
            self.global_fairness(self.weights_cache, temp_ser_fairness)

            print("for round", round_,", global fairness, accuracy, and recall are: ",self.fair_global,", ",
                                                                                  accuracy_server,", ",
                                                                                  recall_ser,"\n\n\n")

            # Update the weights of each client
            # self.aggregate_client_weight_FedFair_part_1(self.model_dict, self.weights_cache,self.acc_clients, self.global_acc,self.fair_local_client,
            #                                      self.fair_global)
            # self.client_weights_update_part_2(self.clients_weights)

            # agg all the weights on serverside
            if algo=="FF":
                self.aggregate_parameters_FedAvg_updated(self.final_agg_weights_clients)
            elif algo=="FA":
                self.aggregate_parameters_FedAvg(self.updated_params_cache,self.weights_cache)
            else:
                print("--Choose Algo --")
                break
        
            
            torch.save(self.global_model.state_dict(), os.path.join(self.temp_dir, f"global_model_{algo}.pt"),)

            

    @torch.no_grad()
    def aggregate_parameters_FedAvg(self, updated_params_cache, weights_cache):
        weight_sum = sum(weights_cache)  
        
        weights = torch.tensor(weights_cache, device=self.device) / weight_sum
        
        aggregated_params = []

        for params in zip(*updated_params_cache):
            aggregated_params.append(
                torch.sum(weights * torch.stack(params, dim=-1), dim=-1)
            )

        self.global_params_dict = OrderedDict(
            zip(self.global_params_dict.keys(), aggregated_params)
        )

    @torch.no_grad()
    def aggregate_parameters_FedAvg_updated(self, model_dict_list):        
        w_avg = copy.deepcopy(model_dict_list[0])
        for k in w_avg.keys():
            for i in range(1, len(model_dict_list)):
                w_avg[k] += model_dict_list[i][k]
            w_avg[k] = torch.div(w_avg[k], len(model_dict_list))
    
        self.global_model.load_state_dict(w_avg)

        
    @torch.no_grad()
    def aggregate_client_weight_FedFair_part_1 (self, model_dict_list ,weights_cache, acc_clients, global_acc,fairness_client,fairness_global):
        

        Delta_t_C_i = abs(self.fair_global - self.fair_local_client) 
        
        for client_id in range(len(self.selected_clients)):
        
            w_avg = copy.deepcopy(model_dict_list[client_id])
            temp = copy.deepcopy(model_dict_list[client_id])
            
            for k in w_avg.keys():
                for i in range(1, len(model_dict_list)):
                    w_avg[k] += model_dict_list[i][k]
                w_avg[k] = torch.div(w_avg[k], len(model_dict_list))
                
                w_avg[k]=temp[k]-(self.beta*(Delta_t_C_i-w_avg[k]))
                
            self.clients_weights.append(w_avg)

            
    def client_weights_update_part_2(self,agg_weights):
        sum_params = OrderedDict()
        
        for model in agg_weights:
            for name, param in model.items():
                if name in sum_params:
                    sum_params[name] += param
                else:
                    sum_params[name] = param

            updated_client_weights = OrderedDict()
            
            for name, param in model.items():
                updated_param = torch.div(param, sum_params[name])
                updated_client_weights[name] = updated_param

                self.final_agg_weights_clients.append(updated_client_weights)
        
  

In [17]:
# FairFed "FF" and FedAVG "FA"

algo="FA" 

if algo=="FA":
    print("Comment out part1 and part2 def_fairFed ")
    
else:
    print("Algorithm is ", algo)
    
global_model_base_path = "/home/chiragapandav/Downloads/Hiwi/Improving-Fairness-via-Federated-Learning/FedFB/models"
shutil.rmtree(global_model_base_path)
os.mkdir(global_model_base_path)

selected_clients = [0, 1, 2, 3]

assigner = AssignModel(global_model_base_path, selected_clients,algo)

temp_model = DeepNet()  # Replace with your model instantiation
assigner.save_global_model(temp_model)
assigner.save_client_models(temp_model)


model = DeepNet()
server = Serverbase(model)
server.fairFed(num_rounds=2,local_epochs=5,learning_rate = 0.001, beta = 1,algo=algo)  

Comment out part1 and part2 def_fairFed 
Getting data statistics....


  0%|                                                                   | 0/2 [00:00<?, ?it/s]

round:: 0
client_id::  0
Epoch 1/5, Val Accuracy: 0.76212%
Epoch 1/5, Val Fairness: 0.01208%
Epoch 1/5, Val Recall: 0.80584%
Epoch 2/5, Val Accuracy: 0.78312%
Epoch 2/5, Val Fairness: 0.01208%
Epoch 2/5, Val Recall: 0.69994%
Epoch 3/5, Val Accuracy: 0.78745%
Epoch 3/5, Val Fairness: 0.01208%
Epoch 3/5, Val Recall: 0.69507%
Epoch 4/5, Val Accuracy: 0.79286%
Epoch 4/5, Val Fairness: 0.01208%
Epoch 4/5, Val Recall: 0.70359%
Epoch 5/5, Val Accuracy: 0.78225%
Epoch 5/5, Val Fairness: 0.01208%
Epoch 5/5, Val Recall: 0.58186%
client_id::  1
Epoch 1/5, Val Accuracy: 0.75541%
Epoch 1/5, Val Fairness: 0.01208%
Epoch 1/5, Val Recall: 0.82424%
Epoch 2/5, Val Accuracy: 0.76397%
Epoch 2/5, Val Fairness: 0.01208%
Epoch 2/5, Val Recall: 0.81458%
Epoch 3/5, Val Accuracy: 0.77925%
Epoch 3/5, Val Fairness: 0.01208%
Epoch 3/5, Val Recall: 0.80444%
Epoch 4/5, Val Accuracy: 0.78581%
Epoch 4/5, Val Fairness: 0.01208%
Epoch 4/5, Val Recall: 0.75761%
Epoch 5/5, Val Accuracy: 0.78708%
Epoch 5/5, Val Fairness: 0

 50%|█████████████████████████████▌                             | 1/2 [00:28<00:28, 28.86s/it]

Epoch 5/5, Val Accuracy: 0.78100%
Epoch 5/5, Val Fairness: 0.01208%
Epoch 5/5, Val Recall: 0.76870%
for round 0 , global fairness, accuracy, and recall are:  0.012078994821371702 ,  0.3668816454592993 ,  1.0 



round:: 1
client_id::  0
Epoch 1/5, Val Accuracy: 0.80216%
Epoch 1/5, Val Fairness: 0.14555%
Epoch 1/5, Val Recall: 0.78698%
Epoch 2/5, Val Accuracy: 0.80606%
Epoch 2/5, Val Fairness: 0.14555%
Epoch 2/5, Val Recall: 0.73707%
Epoch 3/5, Val Accuracy: 0.80346%
Epoch 3/5, Val Fairness: 0.14555%
Epoch 3/5, Val Recall: 0.79245%
Epoch 4/5, Val Accuracy: 0.80974%
Epoch 4/5, Val Fairness: 0.14555%
Epoch 4/5, Val Recall: 0.69142%
Epoch 5/5, Val Accuracy: 0.80974%
Epoch 5/5, Val Fairness: 0.14555%
Epoch 5/5, Val Recall: 0.76446%
client_id::  1
Epoch 1/5, Val Accuracy: 0.79126%
Epoch 1/5, Val Fairness: 0.14555%
Epoch 1/5, Val Recall: 0.76678%
Epoch 2/5, Val Accuracy: 0.79308%
Epoch 2/5, Val Fairness: 0.14555%
Epoch 2/5, Val Recall: 0.75278%
Epoch 3/5, Val Accuracy: 0.79108%
Epoch 3/5, Val

100%|███████████████████████████████████████████████████████████| 2/2 [01:01<00:00, 30.69s/it]

Epoch 5/5, Val Accuracy: 0.78040%
Epoch 5/5, Val Fairness: 0.14555%
Epoch 5/5, Val Recall: 0.85535%
for round 1 , global fairness, accuracy, and recall are:  0.14555174742180343 ,  0.6331183545407008 ,  0.0 






