<a href="https://colab.research.google.com/github/mahinuralam/notebooks/blob/main/FL_DDoS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q flwr["simulation"] tensorflow

In [None]:
!pip install flwr



In [None]:
import math
from typing import Dict, List, Tuple, Any, Optional, Union
import logging
import asyncio
import threading
import traceback
import warnings
import sys
import time
import os

import flwr as fl
from flwr.common import Metrics
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
# from flwr.server import init_defaults, run_fl
from flwr.server.server_config import ServerConfig
from flwr.server.history import History
from flwr.server.strategy import Strategy
from flwr.simulation.ray_transport.ray_actor import (
    ClientAppActor,
    VirtualClientEngineActor,
    VirtualClientEngineActorPool,
    pool_size_from_resources,
)
from flwr.simulation.ray_transport.ray_client_proxy import RayActorClientProxy

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

# Machine learning imports
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPool1D, Flatten, Dense, LSTM, Dropout, BatchNormalization
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import to_categorical, plot_model

# Scikit-learn imports
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, precision_recall_fscore_support, roc_auc_score
)
from sklearn.preprocessing import (
    StandardScaler, LabelEncoder, MinMaxScaler, LabelBinarizer
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_selection import RFE
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import DBSCAN

# Imbalanced data handling
from imblearn.over_sampling import SMOTE, RandomOverSampler

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate
from tqdm import tqdm  # Notebook is not necessary; replaced with the base version




In [None]:
import math
from typing import List, Tuple
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, MaxPool1D, LSTM, Dense
from tensorflow.keras.optimizers import Adam
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_selection import RFE
import flwr as fl

# Configuration constants shared by data loading, the model, and the simulation.
NUM_CLIENTS = 2  # number of training partitions and simulated clients
FEATURES = 46  # features kept by RFE; also the Conv1D input length
EPOCHS = 1  # local epochs per client fit() call
BATCH_SIZE = 32  # batch size for client fit() and evaluate()

# Function to load and process data
def load_and_process_data(
    train_path="/content/part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv",
    test_path="/content/part-00005-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv",
):
    """Load the train/test CSVs and prepare BENIGN-vs-DDos data for the model.

    The fine-grained ``label`` values are collapsed into coarse families, only
    the ``DDos`` and ``BENIGN`` rows are kept and encoded as 1 and 0, the
    remaining columns are cast to float32, ``FEATURES`` columns are selected
    with RFE over a decision tree (fitted on the training split only), and the
    feature matrices are reshaped to ``(samples, FEATURES, 1)``.

    Args:
        train_path: CSV file used for the training split.
        test_path: CSV file used for the test split.

    Returns:
        ``(X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST)``. The X values are arrays
        reshaped to ``(-1, FEATURES, 1)``; the Y values are pandas Series of
        0/1 labels.

    Raises:
        ValueError: If RFE keeps a number of features other than ``FEATURES``.
    """
    df = pd.read_csv(train_path)
    df_test = pd.read_csv(test_path)

    # Coarse family -> fine-grained labels that belong to it.
    label_groups = {
        'DDos': ['DDoS-ICMP_Flood', 'DDoS-UDP_Flood', 'DDoS-TCP_Flood', 'DDoS-PSHACK_Flood',
                 'DDoS-SYN_Flood', 'DDoS-RSTFINFlood', 'DDoS-SynonymousIP_Flood',
                 'DDoS-ICMP_Fragmentation', 'DDoS-UDP_Fragmentation', 'DDoS-ACK_Fragmentation',
                 'DDoS-HTTP_Flood', 'DDoS-SlowLoris'],
        'DoS': ['DoS-UDP_Flood', 'DoS-TCP_Flood', 'DoS-SYN_Flood', 'DoS-HTTP_Flood'],
        'Recon': ['Recon-HostDiscovery', 'Recon-OSScan', 'Recon-PortScan', 'Recon-PingSweep',
                  'VulnerabilityScan'],
        'Spoofing': ['MITM-ArpSpoofing', 'DNS_Spoofing'],
        'BruteForce': ['DictionaryBruteForce'],
        'Web-based': ['BrowserHijacking', 'XSS', 'Uploading_Attack', 'SqlInjection',
                      'CommandInjection', 'Backdoor_Malware'],
        'Mirai': ['Mirai-greeth_flood', 'Mirai-udpplain', 'Mirai-greip_flood'],
        'BENIGN': ['BenignTraffic'],
    }
    fine_to_coarse = {
        fine: coarse for coarse, members in label_groups.items() for fine in members
    }

    def change_label(frame):
        # Assign the result back to the column: an in-place replace on the
        # Series returned by attribute access is a chained update that pandas
        # warns about and that has no effect under Copy-on-Write.
        frame['label'] = frame['label'].replace(fine_to_coarse)

    # Apply label changes to training and testing datasets
    change_label(df)
    change_label(df_test)

    # Keep only the two classes the binary model distinguishes
    wanted = ['DDos', 'BENIGN']
    df_DDOS = df[df['label'].isin(wanted)]
    df_DDOS_test = df_test[df_test['label'].isin(wanted)]

    # Map class labels to numeric
    class_mapping = {'BENIGN': 0, 'DDos': 1}
    Y_TRAIN = df_DDOS['label'].map(class_mapping)
    Y_TEST = df_DDOS_test['label'].map(class_mapping)

    # Drop the label column to obtain only features, as float32
    X_TRAIN = df_DDOS.drop('label', axis=1).astype(np.float32)
    X_TEST = df_DDOS_test.drop('label', axis=1).astype(np.float32)

    # Feature selection is fitted on the training split only, then applied to both
    clf = DecisionTreeClassifier(random_state=0)
    rfe = RFE(estimator=clf, n_features_to_select=FEATURES, step=1)
    rfe.fit(X_TRAIN, Y_TRAIN)

    # Verify selected feature count before reshaping to a fixed width
    if rfe.n_features_ != FEATURES:
        raise ValueError(f"Expected {FEATURES} features but got {rfe.n_features_}")

    # Reshape to (samples, features, 1) so the data matches the Conv1D input
    X_TRAIN = rfe.transform(X_TRAIN).reshape((-1, FEATURES, 1))
    X_TEST = rfe.transform(X_TEST).reshape((-1, FEATURES, 1))

    return (X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST)

# Flower client class
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and scores its own Keras model."""

    def __init__(self, x_train, y_train, x_val, y_val):
        """Create a fresh model via ``get_model()`` and keep the data splits."""
        self.model = get_model()
        self.x_train = x_train
        self.y_train = y_train
        self.x_val = x_val
        self.y_val = y_val

    def get_parameters(self, config):
        """Return the model's current weights; ``config`` is not used."""
        current_weights = self.model.get_weights()
        return current_weights

    def fit(self, parameters, config):
        """Load ``parameters``, train for ``EPOCHS`` epochs, return the result.

        Returns the updated weights, the number of training samples, and an
        empty metrics dict.
        """
        local_model = self.model
        local_model.set_weights(parameters)
        local_model.fit(
            self.x_train,
            self.y_train,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            verbose=0,
        )
        updated_weights = local_model.get_weights()
        return updated_weights, len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and score the model on the validation split.

        Returns the loss, the number of validation samples, and a metrics
        dict holding the accuracy.
        """
        local_model = self.model
        local_model.set_weights(parameters)
        val_loss, val_accuracy = local_model.evaluate(
            self.x_val, self.y_val, batch_size=BATCH_SIZE, verbose=0
        )
        return val_loss, len(self.x_val), {"accuracy": val_accuracy}

# Function to create Flower clients
def client_fn(cid: str) -> fl.client.Client:
    """Build the Flower client for the training partition named by ``cid``.

    Args:
        cid: Client id; parsed as an integer index into ``partitions``.

    Returns:
        A ``fl.client.Client`` wrapping a ``FlowerClient`` that trains on its
        own partition and validates on the shared test split.
    """
    idx = int(cid)
    x_train_part, y_train_part = partitions[idx]
    # Every client validates on the same full test split.
    numpy_client = FlowerClient(x_train_part, y_train_part, X_TEST, Y_TEST)
    # FlowerClient is a NumPyClient; convert it so the returned object matches
    # the declared ``Client`` return type instead of relying on implicit wrapping.
    return numpy_client.to_client()

# Function to define the model
def get_model():
    """Build and compile the Conv1D + LSTM two-class classifier.

    The network takes inputs of shape ``(FEATURES, 1)`` and ends in a two-unit
    softmax layer. It is compiled with Adam (learning rate 0.005), sparse
    categorical cross-entropy, and accuracy as the tracked metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    feature_layers = [
        Conv1D(16, kernel_size=3, activation='relu', input_shape=(FEATURES, 1)),
        MaxPool1D(pool_size=2),
        LSTM(8, return_sequences=False, dropout=0.2),
    ]
    classifier_layers = [
        Dense(8, activation='relu'),
        Dense(2, activation='softmax'),  # one output per class: BENIGN, DDos
    ]
    network = Sequential(feature_layers + classifier_layers)
    network.compile(
        optimizer=Adam(learning_rate=0.005),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

# Load and partition data
(X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST) = load_and_process_data()
# Equal-sized contiguous slices, one per client; rows beyond
# NUM_CLIENTS * partition_size are not assigned to any client.
partition_size = len(X_TRAIN) // NUM_CLIENTS
partitions = [
    (X_TRAIN[window], Y_TRAIN[window])
    for window in (
        slice(k * partition_size, (k + 1) * partition_size)
        for k in range(NUM_CLIENTS)
    )
]

# Define evaluation function
def weighted_average(metrics: List[Tuple[int, dict]]) -> dict:
    """Average client accuracies, weighting each by its number of examples.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs; every ``metrics_dict``
            must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when the total
        example count is zero (no pairs, or every client reported 0 examples),
        since no meaningful average exists in that case.
    """
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, m in metrics:
        total_examples += num_examples
        weighted_sum += num_examples * m["accuracy"]
    # Guard the division: an all-zero example count would otherwise raise.
    if total_examples == 0:
        return {}
    return {"accuracy": weighted_sum / total_examples}

# Define global evaluation function
def evaluate_fn(server_round: int, parameters: List, config: dict) -> Tuple[float, dict]:
    """Score the given global weights on the shared test split.

    ``server_round`` and ``config`` are accepted but not used.

    Returns:
        The test loss and a metrics dict holding the test accuracy.
    """
    global_model = get_model()
    global_model.set_weights(parameters)
    test_loss, test_accuracy = global_model.evaluate(X_TEST, Y_TEST, verbose=0)
    return test_loss, {"accuracy": test_accuracy}

# Set up Flower federated learning strategy
strategy = fl.server.strategy.FedAvg(
    # Minimum client counts: all NUM_CLIENTS for availability, fit and evaluate.
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    # Sampling fractions for the fit and evaluate phases.
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    # Evaluation hooks: server-side scoring and client-metric aggregation.
    evaluate_fn=evaluate_fn,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# # Start Flower simulation
# history = fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=NUM_CLIENTS,
#     config=fl.server.ServerConfig(num_rounds=2),
#     strategy=strategy,
#     client_resources={"num_cpus": 2, "num_gpus": 1}
# )

  and should_run_async(code)


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Define evaluation function for client's local model
def evaluate_local_model(model, x_test, y_test):
    """Score ``model`` on ``(x_test, y_test)`` with four classification metrics.

    The predicted class is the arg-max over axis 1 of ``model.predict``.

    Returns:
        ``(accuracy, precision, recall, f1)``.
    """
    class_scores = model.predict(x_test)
    predicted_labels = np.argmax(class_scores, axis=1)
    scorers = (accuracy_score, precision_score, recall_score, f1_score)
    return tuple(scorer(y_test, predicted_labels) for scorer in scorers)

# Define function to aggregate client results
def aggregate_client_results(client_results):
    """Return the unweighted mean of each metric across clients.

    Args:
        client_results: Sequence of ``(accuracy, precision, recall, f1)``
            tuples, one per client.

    Returns:
        ``(mean_accuracy, mean_precision, mean_recall, mean_f1)``.
    """
    sums = [0, 0, 0, 0]
    for accuracy, precision, recall, f1_value in client_results:
        for slot, value in enumerate((accuracy, precision, recall, f1_value)):
            sums[slot] += value
    count = len(client_results)
    return tuple(total / count for total in sums)

# Initialize lists to store metrics for each round
# Module-level accumulators read by plot_metrics. The client_* lists are meant
# to hold one list of per-client scores per round.
client_accuracy_history = []
client_precision_history = []
client_recall_history = []
client_f1_history = []

# The aggregated_* lists are meant to hold one score per round.
aggregated_accuracy_history = []
aggregated_precision_history = []
aggregated_recall_history = []
aggregated_f1_history = []

# Define evaluation function for Flower
# Define global evaluation function
def evaluate_fn(server_round: int, parameters: List, config: dict) -> Tuple[float, dict]:
    """Score the global weights on the test split and record the metrics.

    Prints the loss and accuracy, then appends the global model's accuracy,
    precision, recall and F1 to the ``aggregated_*_history`` lists, one entry
    per call.

    Only the global ``parameters`` are available here. ``history`` is bound
    only after ``start_simulation`` returns, so reading per-client state from
    it during a round cannot work. The ``client_*_history`` lists are therefore
    not touched by this function.

    Args:
        server_round: Current round number (not used in the computation).
        parameters: Global model weights to evaluate.
        config: Evaluation config (not used).

    Returns:
        The test loss and a metrics dict holding the test accuracy.
    """
    model = get_model()
    model.set_weights(parameters)
    loss, acc = model.evaluate(X_TEST, Y_TEST, verbose=0)
    print("Loss:", loss)
    print("Accuracy:", acc)

    # Metrics of the aggregated (global) model for this round.
    accuracy, precision, recall, f1 = evaluate_local_model(model, X_TEST, Y_TEST)
    aggregated_accuracy_history.append(accuracy)
    aggregated_precision_history.append(precision)
    aggregated_recall_history.append(recall)
    aggregated_f1_history.append(f1)

    return loss, {"accuracy": acc}


# Start Flower simulation
# Size each virtual client to what the host actually has. Requesting a whole
# GPU (and every CPU) per client on a host where Ray reports no GPU leaves no
# room for any actor and the run aborts with "ActorPool is empty". Ask for one
# CPU per client and a GPU share only when TensorFlow can see a GPU.
_gpu_visible = bool(tf.config.list_physical_devices("GPU"))
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=strategy,
    client_resources={
        "num_cpus": 1,
        "num_gpus": 1.0 / NUM_CLIENTS if _gpu_visible else 0.0,
    },
    # Lets actors that share a GPU grow memory on demand; no-op without a GPU.
    actor_kwargs={"on_actor_init_fn": enable_tf_gpu_growth},
)

# Visualize metrics
import matplotlib.pyplot as plt

# # Helper function to plot metrics
# def plot_metrics(history, ylabel, title):
#     plt.figure(figsize=(10, 6))
#     for i, (client_metric, aggregated_metric) in enumerate(zip(history, [aggregated_accuracy_history, aggregated_precision_history, aggregated_recall_history, aggregated_f1_history])):
#         plt.subplot(2, 2, i+1)
#         plt.plot(range(1, len(client_metric)+1), client_metric, marker='o', label='Client')
#         plt.plot(range(1, len(aggregated_metric)+1), aggregated_metric, marker='o', label='Aggregated')
#         plt.xlabel('Round')
#         plt.ylabel(ylabel[i])
#         plt.title(title[i])
#         plt.legend()
#     plt.tight_layout()
#     plt.show()

def plot_metrics(history, ylabel, title):
    """Plot client and aggregated series in a 2x2 grid, one panel per metric.

    Args:
        history: Client metric series, paired in order with the module-level
            aggregated accuracy, precision, recall and F1 histories.
        ylabel: Y-axis label for each panel, indexed by panel position.
        title: Title for each panel, indexed by panel position.
    """
    aggregated_series = [
        aggregated_accuracy_history,
        aggregated_precision_history,
        aggregated_recall_history,
        aggregated_f1_history,
    ]
    plt.figure(figsize=(10, 6))
    for panel, (client_metric, aggregated_metric) in enumerate(zip(history, aggregated_series), start=1):
        print("Client:", client_metric)
        print("Aggregated:", aggregated_metric)
        plt.subplot(2, 2, panel)
        # Rounds are numbered from 1 on the x-axis.
        for series, legend_name in ((client_metric, 'Client'), (aggregated_metric, 'Aggregated')):
            plt.plot(range(1, len(series) + 1), series, marker='o', label=legend_name)
        plt.xlabel('Round')
        plt.ylabel(ylabel[panel - 1])
        plt.title(title[panel - 1])
        plt.legend()
    plt.tight_layout()
    plt.show()

# Draw the four panels (accuracy, precision, recall, F1) from the recorded
# client histories; the labels and titles are listed in the same order.
plot_metrics([client_accuracy_history, client_precision_history, client_recall_history, client_f1_history],
             ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
             ['Client Accuracy', 'Client Precision', 'Client Recall', 'Client F1 Score'])

  and should_run_async(code)
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=2, no round_timeout
INFO:flwr:Starting Flower simulation, config: num_rounds=2, no round_timeout
  self.pid = _posixsubprocess.fork_exec(
2024-05-06 14:34:03,948	INFO worker.py:1621 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'memory': 7535114651.0, 'object_store_memory': 3767557324.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'memory': 7535114651.0, 'object_store_memory': 3767557324.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
INFO:flwr:Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'nu

ValueError: ActorPool is empty. Stopping Simulation. Check 'client_resources' passed to `start_simulation`

In [None]:
! pip install -U flwr["simulation"]



In [None]:

import math
from typing import List, Tuple
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, MaxPool1D, LSTM, Dense
from tensorflow.keras.optimizers import Adam
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_selection import RFE
import flwr as fl

# Configuration constants shared by data loading, the model, and the simulation.
NUM_CLIENTS = 2  # number of training partitions and simulated clients
FEATURES = 46  # features kept by RFE; also the Conv1D input length
EPOCHS = 1  # local epochs per client fit() call
BATCH_SIZE = 32  # batch size for client fit() and evaluate()

# Function to load and process data
def load_and_process_data(
    train_path="/content/part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv",
    test_path="/content/part-00005-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv",
):
    """Load the train/test CSVs and prepare BENIGN-vs-DDos data for the model.

    The fine-grained ``label`` values are collapsed into coarse families, only
    the ``DDos`` and ``BENIGN`` rows are kept and encoded as 1 and 0, the
    remaining columns are cast to float32, ``FEATURES`` columns are selected
    with RFE over a decision tree (fitted on the training split only), and the
    feature matrices are reshaped to ``(samples, FEATURES, 1)``.

    Args:
        train_path: CSV file used for the training split.
        test_path: CSV file used for the test split.

    Returns:
        ``(X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST)``. The X values are arrays
        reshaped to ``(-1, FEATURES, 1)``; the Y values are pandas Series of
        0/1 labels.

    Raises:
        ValueError: If RFE keeps a number of features other than ``FEATURES``.
    """
    df = pd.read_csv(train_path)
    df_test = pd.read_csv(test_path)

    # Coarse family -> fine-grained labels that belong to it.
    label_groups = {
        'DDos': ['DDoS-ICMP_Flood', 'DDoS-UDP_Flood', 'DDoS-TCP_Flood', 'DDoS-PSHACK_Flood',
                 'DDoS-SYN_Flood', 'DDoS-RSTFINFlood', 'DDoS-SynonymousIP_Flood',
                 'DDoS-ICMP_Fragmentation', 'DDoS-UDP_Fragmentation', 'DDoS-ACK_Fragmentation',
                 'DDoS-HTTP_Flood', 'DDoS-SlowLoris'],
        'DoS': ['DoS-UDP_Flood', 'DoS-TCP_Flood', 'DoS-SYN_Flood', 'DoS-HTTP_Flood'],
        'Recon': ['Recon-HostDiscovery', 'Recon-OSScan', 'Recon-PortScan', 'Recon-PingSweep',
                  'VulnerabilityScan'],
        'Spoofing': ['MITM-ArpSpoofing', 'DNS_Spoofing'],
        'BruteForce': ['DictionaryBruteForce'],
        'Web-based': ['BrowserHijacking', 'XSS', 'Uploading_Attack', 'SqlInjection',
                      'CommandInjection', 'Backdoor_Malware'],
        'Mirai': ['Mirai-greeth_flood', 'Mirai-udpplain', 'Mirai-greip_flood'],
        'BENIGN': ['BenignTraffic'],
    }
    fine_to_coarse = {
        fine: coarse for coarse, members in label_groups.items() for fine in members
    }

    def change_label(frame):
        # Assign the result back to the column: an in-place replace on the
        # Series returned by attribute access is a chained update that pandas
        # warns about and that has no effect under Copy-on-Write.
        frame['label'] = frame['label'].replace(fine_to_coarse)

    # Apply label changes to training and testing datasets
    change_label(df)
    change_label(df_test)

    # Keep only the two classes the binary model distinguishes
    wanted = ['DDos', 'BENIGN']
    df_DDOS = df[df['label'].isin(wanted)]
    df_DDOS_test = df_test[df_test['label'].isin(wanted)]

    # Map class labels to numeric
    class_mapping = {'BENIGN': 0, 'DDos': 1}
    Y_TRAIN = df_DDOS['label'].map(class_mapping)
    Y_TEST = df_DDOS_test['label'].map(class_mapping)

    # Drop the label column to obtain only features, as float32
    X_TRAIN = df_DDOS.drop('label', axis=1).astype(np.float32)
    X_TEST = df_DDOS_test.drop('label', axis=1).astype(np.float32)

    # Feature selection is fitted on the training split only, then applied to both
    clf = DecisionTreeClassifier(random_state=0)
    rfe = RFE(estimator=clf, n_features_to_select=FEATURES, step=1)
    rfe.fit(X_TRAIN, Y_TRAIN)

    # Verify selected feature count before reshaping to a fixed width
    if rfe.n_features_ != FEATURES:
        raise ValueError(f"Expected {FEATURES} features but got {rfe.n_features_}")

    # Reshape to (samples, features, 1) so the data matches the Conv1D input
    X_TRAIN = rfe.transform(X_TRAIN).reshape((-1, FEATURES, 1))
    X_TEST = rfe.transform(X_TEST).reshape((-1, FEATURES, 1))

    return (X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST)

# Flower client class
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and scores its own Keras model."""

    def __init__(self, x_train, y_train, x_val, y_val):
        """Create a fresh model via ``get_model()`` and keep the data splits."""
        self.model = get_model()
        self.x_train = x_train
        self.y_train = y_train
        self.x_val = x_val
        self.y_val = y_val

    def get_parameters(self, config):
        """Return the model's current weights; ``config`` is not used."""
        current_weights = self.model.get_weights()
        return current_weights

    def fit(self, parameters, config):
        """Load ``parameters``, train for ``EPOCHS`` epochs, return the result.

        Returns the updated weights, the number of training samples, and an
        empty metrics dict.
        """
        local_model = self.model
        local_model.set_weights(parameters)
        local_model.fit(
            self.x_train,
            self.y_train,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            verbose=0,
        )
        updated_weights = local_model.get_weights()
        return updated_weights, len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and score the model on the validation split.

        Returns the loss, the number of validation samples, and a metrics
        dict holding the accuracy.
        """
        local_model = self.model
        local_model.set_weights(parameters)
        val_loss, val_accuracy = local_model.evaluate(
            self.x_val, self.y_val, batch_size=BATCH_SIZE, verbose=0
        )
        return val_loss, len(self.x_val), {"accuracy": val_accuracy}

# Function to create Flower clients
def client_fn(cid: str) -> fl.client.Client:
    """Build the Flower client for the training partition named by ``cid``.

    Args:
        cid: Client id; parsed as an integer index into ``partitions``.

    Returns:
        A ``fl.client.Client`` wrapping a ``FlowerClient`` that trains on its
        own partition and validates on the shared test split.
    """
    idx = int(cid)
    x_train_part, y_train_part = partitions[idx]
    # Every client validates on the same full test split.
    numpy_client = FlowerClient(x_train_part, y_train_part, X_TEST, Y_TEST)
    # FlowerClient is a NumPyClient; convert it so the returned object matches
    # the declared ``Client`` return type instead of relying on implicit wrapping.
    return numpy_client.to_client()

# Function to define the model
def get_model():
    """Build and compile the Conv1D + LSTM two-class classifier.

    The network takes inputs of shape ``(FEATURES, 1)`` and ends in a two-unit
    softmax layer. It is compiled with Adam (learning rate 0.005), sparse
    categorical cross-entropy, and accuracy as the tracked metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    feature_layers = [
        Conv1D(16, kernel_size=3, activation='relu', input_shape=(FEATURES, 1)),
        MaxPool1D(pool_size=2),
        LSTM(8, return_sequences=False, dropout=0.2),
    ]
    classifier_layers = [
        Dense(8, activation='relu'),
        Dense(2, activation='softmax'),  # one output per class: BENIGN, DDos
    ]
    network = Sequential(feature_layers + classifier_layers)
    network.compile(
        optimizer=Adam(learning_rate=0.005),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

# Load and partition data
(X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST) = load_and_process_data()
# Equal-sized contiguous slices, one per client; rows beyond
# NUM_CLIENTS * partition_size are not assigned to any client.
partition_size = len(X_TRAIN) // NUM_CLIENTS
partitions = [
    (X_TRAIN[window], Y_TRAIN[window])
    for window in (
        slice(k * partition_size, (k + 1) * partition_size)
        for k in range(NUM_CLIENTS)
    )
]

# Define evaluation function
def weighted_average(metrics: List[Tuple[int, dict]]) -> dict:
    """Average client accuracies, weighting each by its number of examples.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs; every ``metrics_dict``
            must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when the total
        example count is zero (no pairs, or every client reported 0 examples),
        since no meaningful average exists in that case.
    """
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, m in metrics:
        total_examples += num_examples
        weighted_sum += num_examples * m["accuracy"]
    # Guard the division: an all-zero example count would otherwise raise.
    if total_examples == 0:
        return {}
    return {"accuracy": weighted_sum / total_examples}

# Define global evaluation function
def evaluate_fn(server_round: int, parameters: List, config: dict) -> Tuple[float, dict]:
    """Score the given global weights on the shared test split.

    ``server_round`` and ``config`` are accepted but not used.

    Returns:
        The test loss and a metrics dict holding the test accuracy.
    """
    global_model = get_model()
    global_model.set_weights(parameters)
    test_loss, test_accuracy = global_model.evaluate(X_TEST, Y_TEST, verbose=0)
    return test_loss, {"accuracy": test_accuracy}

# Set up Flower federated learning strategy
strategy = fl.server.strategy.FedAvg(
    # Minimum client counts: all NUM_CLIENTS for availability, fit and evaluate.
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    # Sampling fractions for the fit and evaluate phases.
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    # Evaluation hooks: server-side scoring and client-metric aggregation.
    evaluate_fn=evaluate_fn,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# # Start Flower simulation
# history = fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=NUM_CLIENTS,
#     config=fl.server.ServerConfig(num_rounds=2),
#     strategy=strategy,
#     client_resources={"num_cpus": 2, "num_gpus": 1}
# )
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Define evaluation function for client's local model
def evaluate_local_model(model, x_test, y_test):
    """Score ``model`` on ``(x_test, y_test)`` with four classification metrics.

    The predicted class is the arg-max over axis 1 of ``model.predict``.

    Returns:
        ``(accuracy, precision, recall, f1)``.
    """
    class_scores = model.predict(x_test)
    predicted_labels = np.argmax(class_scores, axis=1)
    scorers = (accuracy_score, precision_score, recall_score, f1_score)
    return tuple(scorer(y_test, predicted_labels) for scorer in scorers)

# Define function to aggregate client results
def aggregate_client_results(client_results):
    """Return the unweighted mean of each metric across clients.

    Args:
        client_results: Sequence of ``(accuracy, precision, recall, f1)``
            tuples, one per client.

    Returns:
        ``(mean_accuracy, mean_precision, mean_recall, mean_f1)``.
    """
    sums = [0, 0, 0, 0]
    for accuracy, precision, recall, f1_value in client_results:
        for slot, value in enumerate((accuracy, precision, recall, f1_value)):
            sums[slot] += value
    count = len(client_results)
    return tuple(total / count for total in sums)

# Initialize lists to store metrics for each round
# Module-level accumulators read by plot_metrics. The client_* lists are meant
# to hold one list of per-client scores per round.
client_accuracy_history = []
client_precision_history = []
client_recall_history = []
client_f1_history = []

# The aggregated_* lists are meant to hold one score per round.
aggregated_accuracy_history = []
aggregated_precision_history = []
aggregated_recall_history = []
aggregated_f1_history = []

# Define evaluation function for Flower
# Define global evaluation function
def evaluate_fn(server_round: int, parameters: List, config: dict) -> Tuple[float, dict]:
    """Score the global weights on the test split and record the metrics.

    Prints the loss and accuracy, then appends the global model's accuracy,
    precision, recall and F1 to the ``aggregated_*_history`` lists, one entry
    per call.

    Only the global ``parameters`` are available here. ``history`` is bound
    only after ``start_simulation`` returns, so reading per-client state from
    it during a round cannot work. The ``client_*_history`` lists are therefore
    not touched by this function.

    Args:
        server_round: Current round number (not used in the computation).
        parameters: Global model weights to evaluate.
        config: Evaluation config (not used).

    Returns:
        The test loss and a metrics dict holding the test accuracy.
    """
    model = get_model()
    model.set_weights(parameters)
    loss, acc = model.evaluate(X_TEST, Y_TEST, verbose=0)
    print("Loss:", loss)
    print("Accuracy:", acc)

    # Metrics of the aggregated (global) model for this round.
    accuracy, precision, recall, f1 = evaluate_local_model(model, X_TEST, Y_TEST)
    aggregated_accuracy_history.append(accuracy)
    aggregated_precision_history.append(precision)
    aggregated_recall_history.append(recall)
    aggregated_f1_history.append(f1)

    return loss, {"accuracy": acc}


# Start Flower simulation
# Size each virtual client to what the host actually has. Requesting a whole
# GPU (and every CPU) per client on a host where Ray reports no GPU leaves no
# room for any actor and the run aborts with "ActorPool is empty". Ask for one
# CPU per client and a GPU share only when TensorFlow can see a GPU.
_gpu_visible = bool(tf.config.list_physical_devices("GPU"))
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=strategy,
    client_resources={
        "num_cpus": 1,
        "num_gpus": 1.0 / NUM_CLIENTS if _gpu_visible else 0.0,
    },
    # Lets actors that share a GPU grow memory on demand; no-op without a GPU.
    actor_kwargs={"on_actor_init_fn": enable_tf_gpu_growth},
)

# Visualize metrics
import matplotlib.pyplot as plt

# # Helper function to plot metrics
# def plot_metrics(history, ylabel, title):
#     plt.figure(figsize=(10, 6))
#     for i, (client_metric, aggregated_metric) in enumerate(zip(history, [aggregated_accuracy_history, aggregated_precision_history, aggregated_recall_history, aggregated_f1_history])):
#         plt.subplot(2, 2, i+1)
#         plt.plot(range(1, len(client_metric)+1), client_metric, marker='o', label='Client')
#         plt.plot(range(1, len(aggregated_metric)+1), aggregated_metric, marker='o', label='Aggregated')
#         plt.xlabel('Round')
#         plt.ylabel(ylabel[i])
#         plt.title(title[i])
#         plt.legend()
#     plt.tight_layout()
#     plt.show()

def plot_metrics(history, ylabel, title):
    """Plot client and aggregated series in a 2x2 grid, one panel per metric.

    Args:
        history: Client metric series, paired in order with the module-level
            aggregated accuracy, precision, recall and F1 histories.
        ylabel: Y-axis label for each panel, indexed by panel position.
        title: Title for each panel, indexed by panel position.
    """
    aggregated_series = [
        aggregated_accuracy_history,
        aggregated_precision_history,
        aggregated_recall_history,
        aggregated_f1_history,
    ]
    plt.figure(figsize=(10, 6))
    for panel, (client_metric, aggregated_metric) in enumerate(zip(history, aggregated_series), start=1):
        print("Client:", client_metric)
        print("Aggregated:", aggregated_metric)
        plt.subplot(2, 2, panel)
        # Rounds are numbered from 1 on the x-axis.
        for series, legend_name in ((client_metric, 'Client'), (aggregated_metric, 'Aggregated')):
            plt.plot(range(1, len(series) + 1), series, marker='o', label=legend_name)
        plt.xlabel('Round')
        plt.ylabel(ylabel[panel - 1])
        plt.title(title[panel - 1])
        plt.legend()
    plt.tight_layout()
    plt.show()

# Draw the four panels (accuracy, precision, recall, F1) from the recorded
# client histories; the labels and titles are listed in the same order.
plot_metrics([client_accuracy_history, client_precision_history, client_recall_history, client_f1_history],
             ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
             ['Client Accuracy', 'Client Precision', 'Client Recall', 'Client F1 Score'])


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=2, no round_timeout
INFO:flwr:Starting Flower simulation, config: num_rounds=2, no round_timeout
  self.pid = _posixsubprocess.fork_exec(
2024-05-06 14:37:22,251	INFO worker.py:1621 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 3767041228.0, 'memory': 7534082459.0, 'CPU': 2.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'object_store_memory': 3767041228.0, 'memory': 7534082459.0, 'CPU': 2.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
INFO:flwr:Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
IN

ValueError: ActorPool is empty. Stopping Simulation. Check 'client_resources' passed to `start_simulation`