# Multiple-machines Federated Learning
Choose your role! We need one server, one or two malicious clients, and several honest clients.

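Requirements:
- Run `pip install -r requirements.txt`.
- Run `ifconfig` to find your IP address: look under `en0` (Wi-Fi) or your Ethernet interface (e.g. `en7`; the number depends on the machine). When the clients run on other machines, replace the local-host `[::]` in their `server_address` with this IP address.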

## Server

In [1]:
# Libraries
import flwr as fl
import numpy as np
from typing import List, Tuple, Union, Optional, Dict
from flwr.common import Parameters, Scalar, Metrics
from flwr.server.client_proxy import ClientProxy
from flwr.common import FitRes
import argparse
import torch
import utils
import os
from collections import OrderedDict
import json
import time
import pandas as pd

In [2]:
# Define functions
# Config_client
def fit_config(server_round: int):
    """Build the configuration dict sent to every client for one round.

    Args:
        server_round: index of the federated round being configured.

    Returns:
        Dict with the round index ("current_round"), the number of local
        training epochs ("local_epochs") and the total round count
        ("tot_rounds") that clients forward to their metric logging.
    """
    return dict(
        current_round=server_round,
        local_epochs=2,
        tot_rounds=20,
    )

# Custom weighted average function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client evaluation metrics, weighting each client by its example count.

    Args:
        metrics: one ``(num_examples, client_metrics)`` pair per client; every
            ``client_metrics`` dict must contain "accuracy" and "validity".

    Returns:
        ``{"accuracy": ..., "validity": ...}`` averaged with ``num_examples`` as
        weights. Both values are 0.0 when the clients reported no examples at
        all (previously this raised ZeroDivisionError).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to weight by: report neutral zeros instead of dividing by zero.
        return {"accuracy": 0.0, "validity": 0.0}
    weighted_accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    weighted_validity = sum(num_examples * m["validity"] for num_examples, m in metrics)
    return {"accuracy": weighted_accuracy / total_examples, "validity": weighted_validity / total_examples}

# Custom strategy to save model after each round
class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that evaluates every client update server-side and checkpoints the aggregate.

    Each round, every client's returned parameters are loaded into ``model`` and
    scored with ``utils.server_side_evaluation`` on a server-held test subset;
    the per-client results are passed to ``utils.creation_planes_FBPs``. The
    parameters are then aggregated by FedAvg and the aggregated ``state_dict`` is
    saved to ``checkpoint_folder + data_type + "/model_round_{round}.pth"``.

    Args:
        model: torch model instance used as the container for loaded parameters.
        data_type: data-partitioning scheme name (also the checkpoint sub-folder).
        checkpoint_folder: prefix for checkpoint paths (concatenated as a string).
        dataset: dataset name; selects the server test-set subsample size.
        fold: fold index forwarded to ``utils.creation_planes_FBPs``.
        model_config: configuration dict forwarded to the ``utils`` helpers.
        *args, **kwargs: forwarded unchanged to ``FedAvg``.
    """

    # Number of server-side test samples drawn (without replacement) per dataset.
    # Datasets not listed here keep their full test set.
    _SERVER_TEST_SIZES = {
        'diabetes': 300,
        'breast': 88,
        'synthetic': 300,
        'mnist': 300,
        'cifar10': 280,
    }

    def __init__(self, model, data_type, checkpoint_folder, dataset, fold, model_config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.data_type = data_type
        self.checkpoint_folder = checkpoint_folder
        self.dataset = dataset
        self.model_config = model_config
        self.fold = fold

        # Read data for server-side testing and optionally subsample it
        self.X_test, self.y_test = utils.load_data_test(data_type=self.data_type, dataset=self.dataset)
        n_samples = self._SERVER_TEST_SIZES.get(self.dataset)
        if n_samples is not None:
            # Cap at the available size: np.random.choice(replace=False) raises
            # ValueError when asked for more samples than the population holds.
            n_samples = min(n_samples, len(self.X_test))
            idx = np.random.choice(len(self.X_test), n_samples, replace=False)
            self.X_test = self.X_test[idx]
            self.y_test = self.y_test[idx]

        print(f"Used Size Server-Test Set: {self.X_test.shape}")

        # Create the checkpoint folder if it does not exist yet
        os.makedirs(self.checkpoint_folder + f"{self.data_type}", exist_ok=True)

    def _to_state_dict(self, ndarrays: List[np.ndarray]) -> OrderedDict:
        """Pair a list of ndarrays with the model's state_dict keys, in state_dict order."""
        keys = self.model.state_dict().keys()
        return OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, ndarrays))

    # Override aggregate_fit method to add server-side evaluation and saving
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Evaluate each client update, aggregate with FedAvg and store a checkpoint.

        Returns:
            The ``(parameters, metrics)`` pair produced by ``FedAvg.aggregate_fit``.
        """
        # Server-side evaluation of each client's locally trained model, keyed by the 'cid' entry it sent
        client_data = {}
        for _, fit_res in results:
            state_dict = self._to_state_dict(fl.common.parameters_to_ndarrays(fit_res.parameters))
            cid = int(np.round(state_dict['cid'].item()))
            self.model.load_state_dict(state_dict, strict=True)
            try:
                client_data[cid] = utils.server_side_evaluation(self.X_test, self.y_test, model=self.model, config=self.model_config)
            except Exception as e:
                # Best-effort: a failing client is left out of client_data for this round.
                print(f"An error occurred during server-side evaluation of client {cid}: {e}, skipping it for this round")

        # Planes construction
        utils.creation_planes_FBPs(client_data, server_round, self.data_type, self.dataset, self.model_config, self.fold)

        # aggregated_metrics is empty unless fit_metrics_aggregation_fn is passed to the strategy
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        # Save the aggregated model
        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
            self.model.load_state_dict(self._to_state_dict(aggregated_ndarrays), strict=True)
            torch.save(self.model.state_dict(), self.checkpoint_folder + f"{self.data_type}/model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

In [3]:
# Setup arguments and initialize strategy
class Args:
    """Server-side experiment settings (notebook stand-in for command-line arguments)."""

    # Federation layout: clients sampled per round = n_clients + n_attackers
    n_clients: int = 5
    n_attackers: int = 1
    attack_type: str = 'DP_flip'
    rounds: int = 50
    # Data and model selection
    dataset: str = 'diabetes'
    data_type: str = '2cluster'
    model: str = 'net'
    fold: int = 0
    # 1 runs the personalization step after training
    pers: int = 0


args = Args()

In [4]:
# Start server
import shutil

# Fresh results directory for this model/dataset/partition/fold: remove any previous run first.
results_dir = f"results/{args.model}/{args.dataset}/{args.data_type}/{args.fold}"
if os.path.exists(results_dir):
    # shutil.rmtree instead of shelling out to `rm -r`: no shell involved, portable,
    # and a failed delete raises instead of being silently ignored.
    shutil.rmtree(results_dir)
os.makedirs(results_dir)

# Model constructor and dataset/model-specific configuration
model = utils.models[args.model]
config = utils.config_tests[args.dataset][args.model]


def round_config(server_round: int):
    """Return fit_config(server_round) with "tot_rounds" set to the rounds this server actually runs.

    fit_config hard-codes tot_rounds=20, which disagrees with args.rounds; clients
    forward tot_rounds to utils.save_client_metrics, so keep it in sync here.
    """
    round_cfg = fit_config(server_round)
    round_cfg["tot_rounds"] = args.rounds
    return round_cfg


# Every client (honest + malicious) takes part in every round
n_total_clients = args.n_clients + args.n_attackers

# Define strategy
strategy = SaveModelStrategy(
    model=model(config=config),  # model to be trained
    min_fit_clients=n_total_clients,  # never train with fewer than all clients
    min_evaluate_clients=n_total_clients,  # never evaluate with fewer than all clients
    min_available_clients=n_total_clients,  # wait until all clients are connected
    fraction_fit=1.0,  # sample 100% of available clients for training
    fraction_evaluate=1.0,  # sample 100% of available clients for evaluation
    evaluate_metrics_aggregation_fn=weighted_average,
    on_evaluate_config_fn=round_config,
    on_fit_config_fn=round_config,
    data_type=args.data_type,
    checkpoint_folder=config['checkpoint_folder'],
    dataset=args.dataset,
    fold=args.fold,
    model_config=config,
)

# Start time
start_time = time.time()

# Start the Flower server for args.rounds rounds of federated learning
history = fl.server.start_server(
    server_address="0.0.0.0:8098",  # 0.0.0.0 listens on all available interfaces
    config=fl.server.ServerConfig(num_rounds=args.rounds),
    strategy=strategy,
)

# Print training time in minutes (grey color)
training_time = (time.time() - start_time) / 60
print(f"\033[90mTraining time: {round(training_time, 2)} minutes\033[0m")
time.sleep(1)

# Convert the distributed history into plain per-round value lists
loss = [k[1] for k in history.losses_distributed]
accuracy = [k[1] for k in history.metrics_distributed['accuracy']]
validity = [k[1] for k in history.metrics_distributed['validity']]

# Save loss, accuracy and validity to a JSON file
print("Saving metrics as .json in histories folder...")
history_dir = config['history_folder'] + f"server_{args.data_type}"
os.makedirs(history_dir, exist_ok=True)
metrics_path = history_dir + f"/metrics_{args.rounds}_{args.attack_type}_{args.n_attackers}_none_{args.fold}.json"
with open(metrics_path, 'w') as f:
    json.dump({'loss': loss, 'accuracy': accuracy, 'validity': validity}, f)

# Single plot; best_loss_round selects the checkpoint evaluated below
best_loss_round, best_acc_round = utils.plot_loss_and_accuracy(args, loss, accuracy, validity, config=config, show=False)

# Evaluate the model on the test set
if args.model == 'predictor':
    # Separate name so the per-round `accuracy` history list is not overwritten
    y_test_pred, test_accuracy = utils.evaluation_central_test_predictor(args, best_model_round=best_loss_round, config=config)
    print(f"Accuracy on test set: {test_accuracy}")
    df_excel = pd.DataFrame({'accuracy': [test_accuracy]})
    df_excel.to_excel(f"results_fold_{args.fold}.xlsx")
else:
    utils.evaluation_central_test(args, best_model_round=best_loss_round, model=model, config=config)

    # Evaluate distance with all training sets
    df_excel = utils.evaluate_distance(args, best_model_round=best_loss_round, model_fn=model, config=config, spec_client_val=False, training_time=training_time)
    if args.fold != 0:
        df_excel.to_excel(f"results_fold_{args.fold}.xlsx")

# Personalization (done here on the server, but it could equally be done on the client side)
if args.pers == 1:
    start_time = time.time()
    print("\n\n\n\n\033[94mPersonalization\033[0m")
    df_excel_list = utils.personalization(args, model_fn=model, config=config, best_model_round=best_loss_round)
    if args.fold != 0:
        for i in range(args.n_clients):
            print(f"Saving results_fold_{args.fold}_personalization_{i+1}.xlsx")
            df_excel_list[i].to_excel(f"results_fold_{args.fold}_personalization_{i+1}.xlsx")

    # Print personalization time in minutes (grey color)
    print(f"\033[90mPersonalization time: {round((time.time() - start_time)/60, 2)} minutes\033[0m")

# Create gif
utils.create_gif(args, config)

INFO flwr 2024-08-28 14:49:31,190 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=50, round_timeout=None)
INFO flwr 2024-08-28 14:49:31,208 | app.py:176 | Flower ECE: gRPC server running (50 rounds), SSL is disabled
INFO flwr 2024-08-28 14:49:31,208 | server.py:89 | Initializing global parameters
INFO flwr 2024-08-28 14:49:31,209 | server.py:276 | Requesting initial parameters from one random client


Used Size Server-Test Set: torch.Size([300, 21])


KeyboardInterrupt: 

## Malicious Client

In [5]:
# Libraries
from collections import OrderedDict
import torch
import utils
import flwr as fl
import argparse
import numpy as np


In [6]:
# Define the arguments directly in the notebook
class Args:
    """Settings for this malicious client (notebook stand-in for command-line arguments)."""

    # Identity and data
    id: int = 1  # client id (adjust as needed)
    dataset: str = 'diabetes'  # one of 'diabetes', 'breast', 'synthetic', 'mnist', 'cifar10'
    data_type: str = '2cluster'  # one of 'random', 'cluster', '2cluster'
    # Model and attack
    model: str = 'net'  # one of 'net', 'vcnet', 'predictor'
    # 'None', 'DP_flip', 'DP_random', 'DP_inverted_loss', 'DP_inverted_loss_cf',
    # 'MP_random', 'MP_noise' or 'MP_gradient'
    attack_type: str = 'MP_random'


args = Args()

In [7]:
# Define functions 
# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    """Malicious Flower client that carries out a data- or model-poisoning attack.

    Attack types listed in ``_DATA_ATTACKS`` (and "None") train normally and send
    the real parameters; the poisoning presumably lives in the data loaded for
    this client or in the loss (``utils.InvertedLoss`` for "DP_inverted_loss").
    "MP_*" attack types skip training and send manipulated parameters instead.

    Args:
        model: torch model whose ``state_dict`` is exchanged with the server.
        X_train, y_train, X_val, y_val: local training / validation data.
        optimizer: optimizer handed to ``train_fn``.
        num_examples: dict with at least the "trainset" and "valset" sizes.
        client_id: integer id of this client; sent as the 'cid' entry offset by 100.
        data_type: data-partitioning scheme name (used for metric bookkeeping).
        train_fn, evaluate_fn: training / evaluation functions from ``utils``.
        attack_type: one of "None", "DP_flip", "DP_random", "DP_inverted_loss",
            "DP_inverted_loss_cf", "MP_random", "MP_noise", "MP_gradient".
        config_model: model/dataset configuration dict (must contain "history_folder").
    """

    # Attack types that train locally and send the real (unmodified) parameters.
    _DATA_ATTACKS = ("None", "DP_flip", "DP_random", "DP_inverted_loss", "DP_inverted_loss_cf")

    def __init__(self, model, X_train, y_train, X_val, y_val, optimizer, num_examples,
                 client_id, data_type, train_fn, evaluate_fn, attack_type, config_model):
        self.model = model
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.loss_fn = utils.InvertedLoss() if attack_type == "DP_inverted_loss" else torch.nn.CrossEntropyLoss()
        self.optimizer = optimizer
        self.num_examples = num_examples
        self.client_id = client_id
        self.data_type = data_type
        self.train_fn = train_fn
        self.evaluate_fn = evaluate_fn
        self.history_folder = config_model['history_folder']
        self.config_model = config_model
        self.attack_type = attack_type
        self.saved_models = {}  # round -> snapshot of the global state_dict received in that round (MP_gradient)

    def get_parameters(self, config):
        """Return the parameters to send to the server, poisoned according to ``attack_type``.

        ``config`` may be empty (e.g. when the server requests initial
        parameters), so "current_round" is read with ``.get``.

        Raises:
            ValueError: if ``attack_type`` is not a known attack.
        """
        current_round = config.get("current_round")
        # Global model received in the previous round; only MP_gradient uses it.
        # Missing on round 1, on the initial-parameters request, and after a skipped round.
        prev_state = self.saved_models.get(current_round - 1) if current_round is not None else None
        params = []
        for k, v in self.model.state_dict().items():
            if k == 'cid':
                # +100 keeps malicious ids outside the honest-client id range
                # (presumably 1-100); the server keys its per-client results by this value.
                params.append(np.array([self.client_id + 100]))
            elif k == 'mask' or k == 'binary_feature':
                params.append(v.cpu().numpy())
            else:
                params.append(self._poisoned_param(k, v, prev_state))
        return params

    def _poisoned_param(self, key, tensor, prev_state):
        """Return the ndarray to send for one state_dict entry under the configured attack."""
        v = tensor.cpu().numpy()
        # Original parameters
        if self.attack_type in self._DATA_ATTACKS:
            return v
        # Mimic the actual parameter range by sampling from N(mean, std) of the real tensor
        if self.attack_type == "MP_random":
            return np.random.normal(loc=np.mean(v), scale=np.std(v), size=v.shape).astype(np.float32)
        # Introduce random noise on top of the real parameters
        if self.attack_type == "MP_noise":
            return v + np.random.normal(0, 1.2*np.std(v), v.shape).astype(np.float32)
        # Gradient-based attack: push the model back along the estimated update [adaptation of Fall of Empires]
        if self.attack_type == "MP_gradient":
            if prev_state is None:
                # No previous-round snapshot: send the original parameters
                return v
            epsilon = 0.1  # from 0 to 10 --- reverse gradient when epsilon is equal to learning rate
            learning_rate = 0.01
            prev_v = prev_state[key].cpu().numpy()
            gradient = (prev_v - v) / learning_rate  # estimated mean update of all the other clients
            return (v + epsilon * gradient).astype(np.float32)  # apply gradient in the opposite direction
        # An unknown attack would otherwise silently drop this tensor and misalign the parameter list
        raise ValueError(f"Unknown attack_type {self.attack_type!r}")

    def set_parameters(self, parameters):
        """Load a list of ndarrays into the model, in state_dict key order."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load the global model, run the attack-specific local step and return the parameters to send."""
        self.set_parameters(parameters)
        if self.attack_type in self._DATA_ATTACKS:
            # DP_inverted_loss_cf additionally passes inv_loss_cf=True to train_fn (semantics defined in utils)
            extra_kwargs = {"inv_loss_cf": True} if self.attack_type == "DP_inverted_loss_cf" else {}
            try:
                # Return values are unused; train_fn presumably updates self.model in place — confirm in utils
                (_model_trained, _train_loss, _val_loss, _acc, _acc_prime, _acc_val, _) = self.train_fn(
                    self.model, self.loss_fn, self.optimizer, self.X_train, self.y_train,
                    self.X_val, self.y_val, n_epochs=config["local_epochs"], print_info=False,
                    config=self.config_model, **extra_kwargs)
            except Exception:
                # Best-effort: the current model is still returned below
                print("An error occurred during training of Malicious client, returning model with error")

        elif self.attack_type == "MP_gradient":
            # Snapshot the freshly received global model; get_parameters diffs against the previous round's
            current_round = config["current_round"]
            self.saved_models[current_round] = {k: v.clone() for k, v in self.model.state_dict().items()}
            # Keep only the last three snapshots (pop tolerates an already-missing round)
            self.saved_models.pop(current_round - 3, None)
        return self.get_parameters(config), self.num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        """Evaluate the global model on the validation split and log the metrics.

        On any failure a message is printed and sentinel metrics are returned
        (loss 10000, accuracy/validity 0) so the round can continue.
        """
        self.set_parameters(parameters)
        if self.model.__class__.__name__ == "Predictor":
            try:
                loss, accuracy = utils.evaluate_predictor(self.model, self.X_val, self.y_val, self.loss_fn, config=self.config_model)
                # save loss and accuracy client
                utils.save_client_metrics(config["current_round"], loss, accuracy, 0, client_id=self.client_id,
                                          data_type=self.data_type, tot_rounds=config['tot_rounds'], history_folder=self.history_folder)
                return float(loss), self.num_examples["valset"], {"accuracy": float(accuracy), "mean_distance": float(0), "validity": float(0)}
            except Exception:
                print("An error occurred during inference of Malicious client, returning same zero metrics")
                return float(10000), self.num_examples["valset"], {"accuracy": float(0), "mean_distance": float(10000), "validity": float(0)}

        try:
            loss, accuracy, validity, mean_proximity, hamming_distance, euclidian_distance, iou, variability = self.evaluate_fn(
                self.model, self.X_val, self.y_val, self.loss_fn, self.X_train, self.y_train, config=self.config_model)
            # save loss and accuracy client
            utils.save_client_metrics(config["current_round"], loss, accuracy, validity, mean_proximity, hamming_distance, euclidian_distance, iou, variability,
                                      self.client_id, self.data_type, config['tot_rounds'], self.history_folder)
            return float(loss), self.num_examples["valset"], {"accuracy": float(accuracy), "proximity": float(mean_proximity), "validity": float(validity),
                                                              "hamming_distance": float(hamming_distance), "euclidian_distance": float(euclidian_distance),
                                                              "iou": float(iou), "variability": float(variability)}
        except Exception:
            print("An error occurred during inference of Malicious client, returning same zero metrics")
            return float(10000), self.num_examples["valset"], {"accuracy": float(0), "proximity": float(10000), "validity": float(0),
                                                              "hamming_distance": float(10000), "euclidian_distance": float(10000),
                                                              "iou": float(0), "variability": float(0)}


In [8]:
# Start the training
# Flower server address: "[::]:8098" targets this machine (local host). When the
# server runs on another machine, replace "[::]" with the server's IP address.
server_address = "[::]:8098"

# Model constructor, training/evaluation/plot functions and configuration
model = utils.models[args.model]
train_fn = utils.trainings[args.model]
evaluate_fn = utils.evaluations[args.model]
plot_fn = utils.plot_functions[args.model]
config = utils.config_tests[args.dataset][args.model]

# Remove this client's metrics file left over from a previous run, if any
# (question=False presumably skips the confirmation prompt)
utils.check_and_delete_metrics_file(config['history_folder'] + f"malicious_client_{args.data_type}_{args.attack_type}_{args.id}", question=False)

# Check GPU availability and set the manual seed
device = utils.check_gpu(manual_seed=True)

# Load this client's (attack-specific) data
X_train, y_train, X_val, y_val, X_test, y_test, num_examples = utils.load_data_malicious(
    client_id=str(args.id), device=device, type=args.data_type, dataset=args.dataset, attack_type=args.attack_type)

# Model
model = model(config=config).to(device)

# Optimizer (the loss function is chosen inside FlowerClient from the attack type)
optimizer = torch.optim.SGD(model.parameters(), lr=config["learning_rate"], momentum=0.9)

# Start Flower client
client = FlowerClient(model, X_train, y_train, X_val, y_val, optimizer, num_examples, args.id, args.data_type,
                      train_fn, evaluate_fn, args.attack_type, config).to_client()
fl.client.start_client(server_address=server_address, client=client)


MPS is available


INFO flwr 2024-08-28 14:50:07,419 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2024-08-28 14:50:07,420 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2024-08-28 14:50:07,421 | connection.py:42 | ChannelConnectivity.CONNECTING
DEBUG flwr 2024-08-28 14:50:07,422 | connection.py:42 | ChannelConnectivity.READY
DEBUG flwr 2024-08-28 14:50:14,827 | connection.py:141 | gRPC channel closed


KeyboardInterrupt: 

## Honest client

In [9]:
# Libraries
from collections import OrderedDict
import torch
import utils
import flwr as fl
import argparse

In [10]:
# Define functions
# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    """Honest Flower client: trains on its local split and reports validation metrics.

    Args:
        model: torch model whose ``state_dict`` is exchanged with the server.
        X_train, y_train, X_val, y_val: local training / validation data.
        optimizer: optimizer handed to ``train_fn``.
        num_examples: dict with at least the "trainset" and "valset" sizes.
        client_id: id of this client, passed to ``model.set_client_id`` before export.
        data_type: data-partitioning scheme name (used for metric bookkeeping).
        train_fn, evaluate_fn: training / evaluation functions from ``utils``.
        config_model: model/dataset configuration dict (must contain "history_folder").
    """

    def __init__(self, model, X_train, y_train, X_val, y_val, optimizer, num_examples,
                 client_id, data_type, train_fn, evaluate_fn, config_model):
        # Model, loss and training machinery
        self.model = model
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = optimizer
        self.train_fn = train_fn
        self.evaluate_fn = evaluate_fn
        # Local data and split sizes
        self.X_train, self.y_train = X_train, y_train
        self.X_val, self.y_val = X_val, y_val
        self.num_examples = num_examples
        # Identity and bookkeeping
        self.client_id = client_id
        self.data_type = data_type
        self.history_folder = config_model['history_folder']
        self.config = config_model

    def get_parameters(self, config):
        """Tag the model with this client's id, then export every state_dict tensor as an ndarray."""
        self.model.set_client_id(self.client_id)
        state = self.model.state_dict()
        return [tensor.cpu().numpy() for tensor in state.values()]

    def set_parameters(self, parameters):
        """Load a list of ndarrays into the model, pairing them with state_dict keys in order."""
        keys = self.model.state_dict().keys()
        new_state = OrderedDict((name, torch.tensor(array)) for name, array in zip(keys, parameters))
        self.model.load_state_dict(new_state, strict=True)

    def fit(self, parameters, config):
        """Load the global model, train locally and return the (possibly untrained) parameters.

        Any failure while loading or training is printed and the current model is returned anyway.
        """
        try:
            self.set_parameters(parameters)
            (_trained_model, _train_loss, _val_loss, _acc, _acc_prime, _acc_val, _extra) = self.train_fn(
                self.model, self.loss_fn, self.optimizer, self.X_train, self.y_train,
                self.X_val, self.y_val, n_epochs=config["local_epochs"], print_info=False, config=self.config)
        except Exception as err:
            print(f"An error occurred during training of Honest client {self.client_id}: {err}, returning model with error")

        return self.get_parameters(config), self.num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        """Load the global model, evaluate it on the validation split and log the metrics.

        On any failure a message is printed and sentinel metrics are returned
        (loss 10000, accuracy/validity 0) so the round can continue.
        """
        self.set_parameters(parameters)

        if self.model.__class__.__name__ != "Predictor":
            return self._evaluate_full(config)

        try:
            loss, accuracy = utils.evaluate_predictor(self.model, self.X_val, self.y_val, self.loss_fn, config=self.config)
            # Persist this round's loss and accuracy for the client
            utils.save_client_metrics(config["current_round"], loss, accuracy, 0, client_id=self.client_id,
                                      data_type=self.data_type, tot_rounds=config['tot_rounds'],
                                      history_folder=self.history_folder)
            loss_value = float(loss)
            n_val = self.num_examples["valset"]
            return loss_value, n_val, {"accuracy": float(accuracy), "mean_distance": float(0), "validity": float(0)}
        except Exception as err:
            print(f"An error occurred during inference of client {self.client_id}: {err}, returning same zero metrics")
            return float(10000), self.num_examples["valset"], {"accuracy": float(0), "mean_distance": float(10000), "validity": float(0)}

    def _evaluate_full(self, config):
        """Evaluate a non-Predictor model with ``evaluate_fn`` and report the full metric set."""
        try:
            (loss, accuracy, validity, mean_proximity, hamming_distance,
             euclidian_distance, iou, variability) = self.evaluate_fn(
                self.model, self.X_val, self.y_val, self.loss_fn,
                self.X_train, self.y_train, config=self.config)
            # Persist this round's metrics for the client
            utils.save_client_metrics(config["current_round"], loss, accuracy, validity, mean_proximity,
                                      hamming_distance, euclidian_distance, iou, variability,
                                      self.client_id, self.data_type, config['tot_rounds'], self.history_folder)
            loss_value = float(loss)
            n_val = self.num_examples["valset"]
            metrics = {
                "accuracy": float(accuracy),
                "proximity": float(mean_proximity),
                "validity": float(validity),
                "hamming_distance": float(hamming_distance),
                "euclidian_distance": float(euclidian_distance),
                "iou": float(iou),
                "variability": float(variability),
            }
            return loss_value, n_val, metrics
        except Exception as err:
            print(f"An error occurred during inference of client {self.client_id}: {err}, returning same zero metrics")
            fallback = {
                "accuracy": float(0),
                "proximity": float(10000),
                "validity": float(0),
                "hamming_distance": float(10000),
                "euclidian_distance": float(10000),
                "iou": float(0),
                "variability": float(0),
            }
            return float(10000), self.num_examples["valset"], fallback


In [11]:
# Define the arguments directly in the notebook
class Args:
    """Settings for this honest client (notebook stand-in for command-line arguments)."""

    id: int = 1  # client id (adjust as needed, within range 1-100)
    dataset: str = 'diabetes'  # one of 'diabetes', 'breast', 'synthetic', 'mnist', 'cifar10'
    data_type: str = '2cluster'  # one of 'random', 'cluster', '2cluster'
    model: str = 'net'  # one of 'net', 'vcnet', 'predictor'


# Instantiate the settings used by the cells below
args = Args()

In [12]:
# Start training
# Flower server address: "[::]:8098" targets this machine (local host). When the
# server runs on another machine, replace "[::]" with the server's IP address.
server_address = "[::]:8098"

# Model constructor, training/evaluation/plot functions and configuration
model = utils.models[args.model]
train_fn = utils.trainings[args.model]
evaluate_fn = utils.evaluations[args.model]
plot_fn = utils.plot_functions[args.model]
config = utils.config_tests[args.dataset][args.model]

# Remove this client's metrics file left over from a previous run, if any
# (question=False presumably skips the confirmation prompt)
utils.check_and_delete_metrics_file(config['history_folder'] + f"client_{args.data_type}_{args.id}", question=False)

# Check GPU availability and set the manual seed
device = utils.check_gpu(manual_seed=True)

# Load this client's data
X_train, y_train, X_val, y_val, X_test, y_test, num_examples = utils.load_data(
    client_id=str(args.id), device=device, type=args.data_type, dataset=args.dataset)

# Model
model = model(config=config).to(device)

# Optimizer (the loss function is created inside FlowerClient)
optimizer = torch.optim.SGD(model.parameters(), lr=config["learning_rate"], momentum=0.9)

# Start Flower client
client = FlowerClient(model, X_train, y_train, X_val, y_val, optimizer, num_examples, args.id, args.data_type,
                      train_fn, evaluate_fn, config).to_client()
fl.client.start_client(server_address=server_address, client=client)

# Read the saved per-round metrics and plot them
plot_fn(args.id, args.data_type, config, show=False)

INFO flwr 2024-08-28 14:50:22,989 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2024-08-28 14:50:23,002 | connection.py:42 | ChannelConnectivity.IDLE


MPS is available


DEBUG flwr 2024-08-28 14:50:23,211 | connection.py:141 | gRPC channel closed


_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses; last error: UNKNOWN: ipv6:%5B::%5D:8098: Failed to connect to remote host: Connection refused"
	debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv6:%5B::%5D:8098: Failed to connect to remote host: Connection refused", grpc_status:14, created_time:"2024-08-28T14:50:23.004429+02:00"}"
>