### Download and prepare the dataset in Kaggle


In [1]:

# ## Uncomment if the dataset zip file has not been downloaded and extracted yet
# !pip install gdown
# Copy the link. The file ID is the long string of characters between d/ and /view.
#For example, in the URL https://drive.google.com/file/d/1aBcDeFgHiJkLmNoPqRsTuVwXyZ/view?usp=sharing, 
#the file ID is 1aBcDeFgHiJkLmNoPqRsTuVwXyZ
# !mkdir /kaggle/tmp
# !gdown  1pzXpA5Cz0DJmjRsLxlqRNnJq-kOUvojb -O /kaggle/tmp/Labeled_CICMODBUS2023.zip
# !unzip /kaggle/tmp/Labeled_CICMODBUS2023.zip -d /kaggle/working/

# # ## Uncomment if the Python modules (modbus.py, utils.py, ...) have not been cloned and added to the path yet

# !git clone https://github.com/hamid-rd/FLBased-ICS-NIDS.git
# import sys
# # Add the repository folder to the Python path
# repo_path = '/kaggle/working/FLBased-ICS-NIDS'
# repo_input_path = '/kaggle/input/training/FLBased-ICS-NIDS'
# dataset_path = '/kaggle/input/training/'

# for path in {repo_path,repo_input_path,dataset_path}:
#     if path not in sys.path:
#         sys.path.append(path)


In [2]:
# Sanity check: import the project's dataset classes and print a summary of the
# CSV files that ModbusDataset discovers under `dataset_directory`.
# Requires modbus.py (from the FLBased-ICS-NIDS repository) to be importable.
from modbus import ModbusDataset,ModbusFlowStream

# Alternative dataset locations when running on Kaggle:
# dataset_directory = "/kaggle/working/ModbusDataset" 
# dataset_directory = "/kaggle/input/training/ModbusDataset" 
dataset_directory = "dataset" 

# "ready" is passed straight to ModbusDataset; judging by the discovered paths
# (.../ready/...-labeled.csv) it selects the preprocessed CSV sub-folders — confirm in modbus.py.
modbus = ModbusDataset(dataset_directory,"ready")
modbus.summary_print()

# On Kaggle, save a notebook version so that outputs written to disk (/kaggle/working/) are kept.

 The CIC Modbus Dataset contains network (pcap) captures and attack logs from a simulated substation network.
                The dataset is categorized into two groups: an attack dataset and a benign dataset
                The attack dataset includes network traffic captures that simulate various types of Modbus protocol attacks in a substation environment.
                The attacks are reconnaissance, query flooding, loading payloads, delay response, modify length parameters, false data injection, stacking Modbus frames, brute force write and baseline replay.
                These attacks are based of some techniques in the MITRE ICS ATT&CK framework.
                On the other hand, the benign dataset consists of normal network traffic captures representing legitimate Modbus communication within the substation network.
                The purpose of this dataset is to facilitate research, analysis, and development of intrusion detection systems, anomaly detection algorithms and

### Unsupervised training: AE, VAE, adversarial AE and GRU-VAE

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np # For standard deviation calculation
from modbus import ModbusDataset,ModbusFlowStream
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix,recall_score
import torch.optim as optim
from torch.utils.data import DataLoader
import time
from utils import load_scalers
from random import SystemRandom
from sklearn.model_selection import train_test_split
import itertools
import torch.nn.init as init


def compute_threshold(mse_values, k=1):
    """Compute a k-sigma anomaly-detection threshold from reconstruction errors.

    A sample is flagged as an intrusion when its intrusion score (IS) is
    greater than the returned threshold.

    Formula: thr = mean(MSE) + k * std(MSE)

    Args:
        mse_values (torch.Tensor or list/np.ndarray): reconstruction errors
            (e.g. per-sample MSE) obtained from the validation set.
        k (float): number of standard deviations added to the mean.

    Returns:
        tuple[float, float]: ``(threshold, std)``. Both are ``0.0`` when
        ``mse_values`` is empty. ``std`` is ``0.0`` for a single value,
        because the unbiased estimator used by ``torch.std`` is undefined
        (NaN) there.
    """
    if not isinstance(mse_values, torch.Tensor):
        # Go through NumPy: building a tensor from a Python list of NumPy
        # scalars/arrays element by element is slow.
        mse_values = torch.as_tensor(np.asarray(mse_values, dtype=np.float32))
    if mse_values.numel() == 0:
        # Always return a pair so callers can unpack `thr, std = ...`.
        return 0.0, 0.0
    mean_mse = torch.mean(mse_values)
    if mse_values.numel() > 1:
        std_mse = torch.std(mse_values)
    else:
        std_mse = torch.zeros_like(mean_mse)
    print("-----------mse_loss mean : ",f"{mean_mse.item():.4f}","std:",f"{std_mse.item():.4f}")
    threshold = mean_mse + k * std_mse
    return threshold.item(), std_mse.item()

def vae_loss_function(recon_x, x, mu, logvar, beta=1):
    """Return the summed VAE objective: reconstruction error + beta * KL term.

    The reconstruction part is the summed squared error between ``recon_x``
    and ``x``. The regularisation part is the closed-form KL divergence between
    N(mu, exp(logvar)) and a standard normal, summed over every element.
    """
    reconstruction = F.mse_loss(recon_x, x, reduction="sum")
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + beta * kl_divergence


In [4]:
# Re-create the dataset index (same call as the sanity-check cell) so this section can be run on its own.
# dataset_directory = "/kaggle/input/training/ModbusDataset" # change this to the folder containing the benign and attack sub-directories
dataset_directory = "dataset" 
modbus = ModbusDataset(dataset_directory,"ready")
modbus.summary_print()

 The CIC Modbus Dataset contains network (pcap) captures and attack logs from a simulated substation network.
                The dataset is categorized into two groups: an attack dataset and a benign dataset
                The attack dataset includes network traffic captures that simulate various types of Modbus protocol attacks in a substation environment.
                The attacks are reconnaissance, query flooding, loading payloads, delay response, modify length parameters, false data injection, stacking Modbus frames, brute force write and baseline replay.
                These attacks are based of some techniques in the MITRE ICS ATT&CK framework.
                On the other hand, the benign dataset consists of normal network traffic captures representing legitimate Modbus communication within the substation network.
                The purpose of this dataset is to facilitate research, analysis, and development of intrusion detection systems, anomaly detection algorithms and

In [5]:
def _init_weights( module):
    ## for one layer apply Xavier Initialization
    if isinstance(module, nn.Linear):
        init.xavier_normal_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)
    return module
# AutoEncoder (AE)
class AE(nn.Module):
    """Fully connected autoencoder.

    Encoder: input_dim -> 64 -> 32 (ReLU after each layer)
    Decoder: 32 -> 64 -> input_dim (ReLU, then Sigmoid on the output)

    ``forward`` returns only the reconstruction, squashed into (0, 1).
    """

    def __init__(self, input_dim=89):
        super().__init__()
        encoder_layers = [
            nn.Linear(input_dim, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
        ]
        decoder_layers = [
            nn.Linear(32, 64), nn.ReLU(),
            nn.Linear(64, input_dim), nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))



In [6]:
from collections import Counter

def _reconstruct(model, model_name, sequences):
    """Return only the reconstruction produced by ``model`` for ``sequences``.

    AE returns the reconstruction directly, VAE/GRUVAE return
    ``(recon, mu, logvar)`` and AdversarialAutoencoder returns
    ``(latent, recon)``. This helper hides those differences.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    if model_name == "AE":
        return model(sequences)
    if model_name in ("VAE", "GRUVAE"):
        recon, _, _ = model(sequences)
        return recon
    if model_name == "AdversarialAutoencoder":
        _, recon = model(sequences)
        return recon
    raise ValueError(f"Unsupported model type: {model_name!r}")


def _per_sample_scores(eval_criterion, recon, sequences, labels):
    """Return per-sample reconstruction scores (summed over dim 1) and labels.

    ``eval_criterion`` is an element-wise loss (``reduction='none'``). When a
    batch collapsed to one sample (a loss with dim <= 1 after the caller's
    ``squeeze()``), both scores and labels get a leading batch dim back.
    """
    scores = eval_criterion(recon, sequences)
    if scores.dim() <= 1:
        scores = scores.unsqueeze(dim=0)
        labels = labels.unsqueeze(dim=0)
    return scores.sum(dim=1), labels


def train_eval(model, train_dataloader, val_dataloader, test_dataloader,
               learning_rates=(5e-6, 1e-7, 5e-5, 1e-5, 1e-6),
               weight_decays=(1e-5, 1e-4, 1e-7), shuffle_files=True,
               num_epochs=20, eval_epoch=4, criterion_method="mse"):
    """Grid-search, train and evaluate an unsupervised reconstruction model.

    For every ``(lr, wd)`` pair, the model's ``nn.Linear`` layers are
    re-initialised. The model is then trained for ``num_epochs`` epochs on
    batches whose labels are all 0; batches containing any non-zero label are
    skipped. Every ``eval_epoch`` epochs:

    * the reconstruction errors of all-zero-label validation batches give a
      mean and std (via ``compute_threshold``);
    * test samples are flagged as intrusions when their summed reconstruction
      error exceeds ``mean + k * std`` for k in (1, 3), and accuracy, recall,
      FDR and F1 are printed;
    * the weights are saved under ``best_models/`` whenever the best recall
      (k=1) or best F1 (k=3) seen so far in this call improves.

    Args:
        model: ``AE``, ``VAE``, ``GRUVAE`` or ``AdversarialAutoencoder``
            (dispatch is on ``model._get_name()``).
        train_dataloader, val_dataloader, test_dataloader: iterables yielding
            ``(sequences, labels)`` batches.
        learning_rates, weight_decays: hyper-parameter grid (any iterables).
        shuffle_files: shuffle ``train_dataloader.dataset.csv_files`` in place
            at the start of every epoch.
        num_epochs: training epochs per grid point.
        eval_epoch: evaluate every ``eval_epoch`` epochs.
        criterion_method: ``"bce"`` for binary cross-entropy, anything else for
            MSE. Used for AE/AAE training and for all evaluation scores;
            VAE/GRUVAE training always uses ``vae_loss_function``.

    Raises:
        ValueError: if the model type is not supported.
    """
    import os  # needed for os.makedirs; not among the file-level imports

    model_name = model._get_name()
    supported_models = ("AE", "VAE", "GRUVAE", "AdversarialAutoencoder")
    if model_name not in supported_models:
        # Fail early instead of hitting an UnboundLocalError on `loss` later.
        raise ValueError(
            f"Unsupported model type {model_name!r}; expected one of {supported_models}"
        )
    is_aae = model_name == "AdversarialAutoencoder"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if criterion_method == "bce":
        criterion = nn.BCELoss(reduction='sum').to(device)
        eval_criterion = nn.BCELoss(reduction='none').to(device)
    else:  # mse
        criterion = nn.MSELoss(reduction='sum').to(device)
        eval_criterion = nn.MSELoss(reduction='none').to(device)
    best_f1 = 0  # best test F1 (k=3 threshold) seen so far
    best_recall = 0  # best test recall (k=1 threshold) seen so far

    for lr, wd in itertools.product(learning_rates, weight_decays):
        if is_aae:
            adversarial_criterion = nn.BCELoss(reduction="sum")
            optimizer_D = optim.SGD(model.discriminator.parameters(), lr=lr, weight_decay=wd)
            optimizer_G = optim.SGD(
                list(model.encoder.parameters()) + list(model.decoder.parameters()),
                lr=lr, weight_decay=wd,
            )
        else:
            AE_optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

        print(f"\n==================  lr={lr}, wd={wd} ==================")
        # Re-initialise nn.Linear layers for each grid point. NOTE: other layer
        # types (e.g. nn.GRU in GRUVAE) are not reset by _init_weights.
        model.apply(_init_weights)
        for epoch in range(num_epochs):
            time_1 = time.time()
            model.train()
            train_loss = 0
            discriminator_loss = 0  # only accumulated for the AAE
            used_batches = 0
            if shuffle_files:
                SystemRandom().shuffle(train_dataloader.dataset.csv_files)
            for sequences, labels in train_dataloader:
                # Train only on batches whose labels are all 0 (benign).
                if labels.sum() != 0:
                    continue
                used_batches += 1
                sequences = sequences.squeeze().to(device)
                if is_aae:
                    batch = sequences.size(0)
                    target_ones = torch.ones(batch, 1, device=device, dtype=torch.float)
                    target_zeros = torch.zeros(batch, 1, device=device, dtype=torch.float)
                    # Standard-normal samples in the 2-D latent space ("real" codes).
                    random_latent = torch.randn(batch, 2, device=device)
                    # 1) generator: reconstruction + fooling the discriminator
                    optimizer_G.zero_grad()
                    fake_z, decoded_seq = model(sequences)
                    G_loss = (0.001 * adversarial_criterion(model.discriminator(fake_z), target_ones)
                              + 0.999 * criterion(decoded_seq, sequences))
                    G_loss.backward()
                    optimizer_G.step()
                    # 2) discriminator: prior samples are "real", encoder codes "fake"
                    optimizer_D.zero_grad()
                    real_loss = adversarial_criterion(model.discriminator(random_latent), target_ones)
                    fake_loss = adversarial_criterion(model.discriminator(fake_z.detach()), target_zeros)
                    D_loss = 0.5 * (real_loss + fake_loss)
                    D_loss.backward()
                    optimizer_D.step()
                    train_loss += G_loss.item()
                    discriminator_loss += D_loss.item()
                else:
                    AE_optimizer.zero_grad()
                    if model_name == "AE":
                        recon = model(sequences)
                        loss = criterion(recon, sequences) / sequences.size(0)
                    else:  # VAE / GRUVAE
                        recon, mu, logvar = model(sequences)
                        loss = vae_loss_function(recon, sequences, mu, logvar) / sequences.size(0)
                    loss.backward()
                    AE_optimizer.step()
                    train_loss += loss.item()
            # Average over the batches actually trained on, not the skipped ones.
            n_batches = max(used_batches, 1)
            print(f"Train : time {(time.time()-time_1):.2f} s",
                  f"Epoch {epoch+1}")
            if is_aae:
                print(f"Generator Loss: {train_loss / n_batches:.4f}",
                      f"Discriminator Loss: {discriminator_loss / n_batches:.4f}")
            else:
                print(f"Train Loss: {train_loss / n_batches:.4f}")

            if (epoch + 1) % eval_epoch != 0:
                continue

            # --- Validation: fit the threshold on all-zero-label batches ---
            model.eval()
            all_val_losses = []
            all_val_labels = []
            print(f"--- Running Evaluation for Epoch {epoch+1} lr ={lr} wd {wd} ---")
            with torch.no_grad():
                for sequences, labels in val_dataloader:
                    if labels.sum() != 0:
                        continue
                    sequences = sequences.squeeze().to(device)
                    if criterion_method == "bce":
                        # scaled features may exceed [0, 1]; BCE requires that range
                        sequences = torch.clamp(sequences, min=0.0, max=1.0)
                    recon = _reconstruct(model, model_name, sequences)
                    val_loss, labels = _per_sample_scores(eval_criterion, recon, sequences, labels)
                    all_val_losses.extend(val_loss.cpu().numpy())
                    all_val_labels.extend(labels.flatten().cpu().numpy())
            if not all_val_losses:
                print("No benign validation batches found; skipping evaluation.")
                continue
            threshold_1, std_mse = compute_threshold(all_val_losses, k=0)

            all_val_losses = np.atleast_1d(np.array(all_val_losses).squeeze())
            all_val_labels = np.atleast_1d(np.array(all_val_labels).squeeze())
            # intrusion score > threshold -> predict 1 (intrusion), else 0 (benign)
            predictions = (all_val_losses > threshold_1).astype(int)
            accuracy = accuracy_score(all_val_labels, predictions)
            print(f"Val: Accuracy: {accuracy:.4f}  ")

            # --- Test: score every sample, benign and attack ---
            all_test_losses = []
            all_test_labels = []
            with torch.no_grad():
                for sequences, labels in test_dataloader:
                    sequences = sequences.squeeze().to(device)
                    labels = labels.squeeze()
                    if criterion_method == "bce":
                        sequences = torch.clamp(sequences, min=0.0, max=1.0)
                    recon = _reconstruct(model, model_name, sequences)
                    intrusion_scores, labels = _per_sample_scores(eval_criterion, recon, sequences, labels)
                    all_test_losses.extend(intrusion_scores.cpu().numpy())
                    all_test_labels.extend(labels.cpu().numpy())

            all_test_losses = np.array(all_test_losses)
            all_test_labels = np.array(all_test_labels)
            # Any non-zero (attack-type) label counts as an intrusion.
            binary_test_labels = (all_test_labels != 0).astype(int)
            prev_best_recall = best_recall
            prev_best_f1 = best_f1

            for k in (1, 3):
                threshold = threshold_1 + k * std_mse
                print(f" K: {k} K-SIGMA Threshold : ---thr {threshold:.4}")
                predictions = (all_test_losses > threshold).astype(int)
                # Original (multi-class) labels of the misclassified samples.
                misclassified_original_labels = all_test_labels[binary_test_labels != predictions]
                print(Counter(predictions), Counter(binary_test_labels))
                print(f"Counts of  labels: {dict(sorted(Counter(all_test_labels).items()))}")
                print(f"Counts of misclassified original labels: {dict(sorted(Counter(misclassified_original_labels).items()))}")

                accuracy = accuracy_score(binary_test_labels, predictions)
                f1 = f1_score(binary_test_labels, predictions, zero_division=0)
                recall = recall_score(binary_test_labels, predictions, zero_division=0)
                _, fp, _, tp = confusion_matrix(binary_test_labels, predictions, labels=[0, 1]).ravel()
                # FDR = FP / (FP + TP), defined as 0 when nothing is flagged.
                fdr = fp / (fp + tp) if (fp + tp) else 0.0
                print(f"Test : Accuracy: {accuracy:.4f} Recall : {recall:.4f} FDR: {fdr:.4f}  F1-score: {f1:.4f}  ")
                if k == 3 and f1 > best_f1:
                    best_f1 = f1
                elif k == 1 and recall > best_recall:
                    best_recall = recall

            if best_recall > prev_best_recall or best_f1 > prev_best_f1:
                # Portable replacement for the IPython-only `!mkdir best_models -p`.
                os.makedirs("best_models", exist_ok=True)
                save_path = f"best_models/{model_name}_f1_{best_f1:.2f}_recall_{best_recall:.2f}_.pth"
                torch.save(model.state_dict(), save_path)
                print("model", model_name, "is saved in", save_path)


#### Centralized : external scenario -> ied1a node

In [7]:
# Centralized split: train on benign network-wide captures, validate (fit the
# threshold) on benign ied1a captures, test on ied1a external-attack captures.
benign_files = modbus.dataset["benign_dataset_dir"]
train_files = [path for path in benign_files if "network-wide" in path]
val_files = [path for path in benign_files if "ied1a" in path]
test_files = [path for path in modbus.dataset["attack_dataset_dir"]["external"] if "ied1a" in path]
sys_rand = SystemRandom()

# Shuffle each file list in place (OS entropy, so the order differs between runs).
for file_list in (train_files, val_files, test_files):
    sys_rand.shuffle(file_list)


# Label fixed: the test split is ied1a under the *external* attack scenario.
print("ied1a external attack ->\n test: ",len(test_files),test_files)
print("----------- Network-wide number of csv files -> \n ----------- train :",len(train_files),train_files,"\n -------- valid:",len(val_files),val_files)

ied1b comp ied attack ->
 test:  1 ['dataset/ModbusDataset/attack/external/ied1a/ied1a-network-capture/ready/veth4edc015-0-labeled.csv']
----------- Network-wide number of csv files -> 
 ----------- train : 19 ['dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-16-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-28-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-22-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-18-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-30-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-20-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-26-labeled.csv', 'data

In [8]:
# Data-loading strategy switch.
# use_cow=True: load every split fully into RAM once; the intent is that the
#   DataLoader worker processes forked from this process share that single
#   in-memory copy (copy-on-write). Requires enough RAM for all the tensors.
# use_cow=False: read the CSV files iteratively with chunk_size=1 instead.

use_cow=True
# Window length passed to ModbusFlowStream (window_size=1 presumably means one flow record per sample — confirm in modbus.py).
window_size=1
# Scalers read from the 'fitted_scalers' directory; the next cell uses
# loaded_scalers['network-wide']['min_max_scalers'].
loaded_scalers=load_scalers('fitted_scalers')

Successfully loaded scalers for 'network-wide'


In [9]:


# This cell builds train_dataloader, val_dataloader and test_dataloader.
#
# Two data-loading strategies are supported:
# 1. Copy-on-Write (use_cow=True): each split is loaded into RAM in one chunk.
#    Fast but memory-intensive; the in-memory data is intended to be shared
#    with the forked DataLoader worker processes.
# 2. Iterative (use_cow=False): CSV files are read in small chunks. Slower but
#    uses far less memory, for datasets that do not fit in RAM.
#
# Inputs (globals from previous cells):
#     train_files, val_files, test_files (list): CSV file paths per split.
#     window_size (int): sliding-window length for the sequence data.
#     use_cow (bool): selects the strategy above.
#     loaded_scalers: fitted scalers; the 'network-wide' min-max scalers are used.
# Outputs: train_dataloader, val_dataloader, test_dataloader.

if use_cow:
    # Total number of discovered CSV files; presumably chunk_size counts files,
    # so a single chunk covers a whole split — confirm in modbus.py.
    large_chunk_size = modbus.dataset["metadata"]["founded_files_num"]["total_dataset_num"]

    dataset_configs = {
        "train": {"files": train_files},
        "val": {"files": val_files},
        "test": {"files": test_files},
    }
    datasets = {}
    ae_datasets = {}

    print("Cow Processing datasets...")

    for name, config in dataset_configs.items():
        print(f"  - Creating '{name}' dataset...")

        # 1. Create the primary ModbusFlowStream dataset (reads and scales the CSVs).
        datasets[name] = ModbusFlowStream(
            shuffle=False,
            chunk_size=large_chunk_size,
            batch_size=1,
            csv_files=config["files"],
            scalers=loaded_scalers['network-wide']['min_max_scalers'],
            window_size=window_size
        )

        # 2. Call __getitem__(0) once to load the entire dataset chunk into memory.
        datasets[name].__getitem__(0)

        # 3. Create an empty twin dataset (no CSV files, no scalers) used for
        #    AE training/evaluation without re-reading the files.
        ae_datasets[name] = ModbusFlowStream(
            shuffle=False,
            chunk_size=large_chunk_size,
            batch_size=1,
            csv_files=[],
            scalers=None,  # data is already scaled by the primary dataset
            window_size=window_size
        )

        # 4. Share the already-loaded data and its bookkeeping with the twin.
        ae_datasets[name].current_chunk_data = datasets[name].current_chunk_data
        ae_datasets[name].current_len_chunk_data = datasets[name].current_len_chunk_data
        ae_datasets[name].current_chunk_labels = datasets[name].current_chunk_labels
        ae_datasets[name].total_batches = datasets[name].total_batches

        print(f"  - Finished '{name}' dataset.")

    # Settings shared by all three loaders: 64-sample batches, 4 persistent workers.
    loader_kwargs = dict(batch_size=64, num_workers=4, persistent_workers=True,
                         prefetch_factor=2, pin_memory=True)
    train_dataloader = DataLoader(ae_datasets['train'], shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(ae_datasets['val'], shuffle=False, **loader_kwargs)
    test_dataloader = DataLoader(ae_datasets['test'], shuffle=False, **loader_kwargs)

else:
    # Here ModbusFlowStream is given batch_size=64, so the DataLoader itself
    # uses batch_size=1 (each item is presumably already a batch — confirm in modbus.py).
    minmax_scalers = loaded_scalers['network-wide']['min_max_scalers']
    train_dataloader = DataLoader(ModbusFlowStream(
        shuffle=True, chunk_size=1, batch_size=64, csv_files=train_files,
        scalers=minmax_scalers, window_size=window_size
    ), batch_size=1, shuffle=False)
    val_dataloader = DataLoader(ModbusFlowStream(
        shuffle=False, chunk_size=1, batch_size=64, csv_files=val_files,
        scalers=minmax_scalers, window_size=window_size
    ), batch_size=1, shuffle=False)
    test_dataloader = DataLoader(ModbusFlowStream(
        shuffle=False, chunk_size=1, batch_size=64, csv_files=test_files,
        scalers=minmax_scalers, window_size=window_size
    ), batch_size=1, shuffle=False)


Cow Processing datasets...
  - Creating 'train' dataset...
  - Finished 'train' dataset.
  - Creating 'val' dataset...
  - Finished 'val' dataset.
  - Creating 'test' dataset...
  - Finished 'test' dataset.


In [10]:
# Number of batches in each loader (train, val, test).
print(len(train_dataloader),len(val_dataloader),len(test_dataloader))

44101 19154 1960


In [None]:
# Alternative run with a BCE reconstruction loss:
# train_eval(AE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=True,num_epochs=20,eval_epoch=1,criterion_method="bce",learning_rates=[5e-4],weight_decays=[0])
# Plain AE on the 76 input features used throughout this notebook.
AE_model = AE(input_dim=76)
# Grid: 4 learning rates x 1 weight decay, 6 epochs each, evaluated after every epoch.
train_eval(AE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[5e-5,1e-4,1e-5,1e-6],weight_decays=[1e-4])



Train : time 163.55 s Epoch 1
Train Loss: 0.2204
--- Running Evaluation for Epoch 1 lr =5e-05 wd 0.0001 ---
-----------mse_loss mean :  0.0031 std: 0.0626
Val: Accuracy: 0.9751  
 K: 1 K-SIGMA Threshold : ---thr 0.06569
Counter({1: 65686, 0: 59746}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 29808, 1: 137}
Test : Accuracy: 0.7613 Recall : 0.9962 FDR: 0.4538  F1-score: 0.7056  
 K: 3 K-SIGMA Threshold : ---thr 0.1908
Counter({1: 65565, 0: 59867}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 29715, 1: 165}
Test : Accuracy: 0.7618 Recall : 0.9954 FDR: 0.4532  F1-score: 0.7058  
model AE is saved in best_models/AE_f1_0.71_recall_1.00_.pth
Train : time 174.78 s Epoch 2
Train Loss: 0.0026
--- Running Evaluation for Epoch 2 lr =5e-05 wd 0.0001 ---
-----------mse_loss mean :  0.00

In [None]:

# Variational AutoEncoder (VAE)
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: input_dim -> 64 -> 64, followed by two 64 -> 32 heads that give
    the latent mean (``fc_mu``) and log-variance (``fc_logvar``).
    Decoder: 32 -> 64 -> 64 -> input_dim with a Sigmoid output.

    ``forward`` returns ``(x_recon, mu, logvar)``.
    """

    def __init__(self, input_dim=89):
        super().__init__()
        hidden, latent = 64, 32
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
        self.decoder = nn.Sequential(
            nn.Linear(latent, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, input_dim), nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) (reparameterisation trick)."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        features = self.encoder(x)
        mu, logvar = self.fc_mu(features), self.fc_logvar(features)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

    

In [None]:
# VAE on the 76 input features; trained with vae_loss_function inside train_eval.
VAE_model = VAE(input_dim=76)
# Grid: 3 learning rates x 1 weight decay, 6 epochs each, evaluated after every epoch.
train_eval(VAE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[1e-2,1e-3,1e-4],weight_decays=[1e-4])



Train : time 31.80 s Epoch 1
Train Loss: 0.0971
--- Running Evaluation for Epoch 1 lr =0.0005 wd 0 ---
-----------mse_loss mean :  5.7442 std: 1.6715
Val: Accuracy: 0.9991  
---thr 74.16
Counter({0: 246158, 1: 28}) Counter({0: 209532, 1: 36654})
Counts of  labels: {0: 209532, 1: 36432, 2: 49, 3: 30, 4: 49, 5: 39, 6: 31, 7: 24}
Counts of misclassified original labels: {0: 28, 1: 36432, 2: 49, 3: 30, 4: 49, 5: 39, 6: 31, 7: 24}
Test : Accuracy: 0.8510 Recall : 0.0000 FDR: 1.0000  F1-score: 0.0000  
---thr 37.08
Counter({0: 246044, 1: 142}) Counter({0: 209532, 1: 36654})
Counts of  labels: {0: 209532, 1: 36432, 2: 49, 3: 30, 4: 49, 5: 39, 6: 31, 7: 24}
Counts of misclassified original labels: {0: 141, 1: 36432, 2: 49, 3: 30, 4: 49, 5: 38, 6: 31, 7: 24}
Test : Accuracy: 0.8505 Recall : 0.0000 FDR: 0.9930  F1-score: 0.0001  
---thr 14.83
Counter({0: 178037, 1: 68149}) Counter({0: 209532, 1: 36654})
Counts of  labels: {0: 209532, 1: 36432, 2: 49, 3: 30, 4: 49, 5: 39, 6: 31, 7: 24}
Counts of

KeyboardInterrupt: 

In [None]:

class AAE_Encoder(nn.Module):
    """Encoder (generator) of the adversarial autoencoder.

    Architecture: input_dim -> 16 -> 4 -> 2 with LeakyReLU(0.2) between
    layers; the 2-D latent output is linear (unbounded).
    """

    def __init__(self, input_dim=76):
        super(AAE_Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 4),
            nn.LeakyReLU(0.2),
            nn.Linear(4, 2))

    def forward(self, x):
        return self.encoder(x)


class AAE_Decoder(nn.Module):
    """Decoder of the adversarial autoencoder: 2 -> 4 -> 16 -> input_dim.

    LeakyReLU (PyTorch default slope) between layers and a Sigmoid output,
    so reconstructions lie in (0, 1).
    """

    def __init__(self, input_dim=76):
        super(AAE_Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(2, 4),
            nn.LeakyReLU(),
            nn.Linear(4, 16),
            nn.LeakyReLU(),
            nn.Linear(16, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(x)


class AAE_Discriminator(nn.Module):
    """Discriminator on 2-D latent codes: 2 -> 16 -> 4 -> 1, Sigmoid output."""

    def __init__(self):
        super(AAE_Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(2, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 4),
            nn.LeakyReLU(),
            nn.Linear(4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.discriminator(x)


class AdversarialAutoencoder(nn.Module):
    """Adversarial autoencoder bundling an encoder, decoder and discriminator.

    ``forward`` returns ``(fake_z, x_recon)``: the 2-D latent code and the
    reconstruction. The discriminator is not used in ``forward``; the
    training code calls ``model.discriminator`` directly.

    Args:
        input_dim: number of input features. Defaults to 76, the value that
            was previously hard-coded through the sub-module defaults.
    """

    def __init__(self, input_dim=76):
        super(AdversarialAutoencoder, self).__init__()
        self.encoder = AAE_Encoder(input_dim=input_dim)
        self.decoder = AAE_Decoder(input_dim=input_dim)
        self.discriminator = AAE_Discriminator()

    def forward(self, x):
        fake_z = self.encoder(x)
        x_recon = self.decoder(fake_z)
        return fake_z, x_recon


In [None]:


# Adversarial autoencoder run with train_eval's default hyper-parameter grid.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
aae_model = AdversarialAutoencoder().to(device)
# Sanity check that the parameters actually live on the chosen device.
device = next(aae_model.parameters()).device
print(f"Model device (Method 1): {device}")
print(next(aae_model.decoder.parameters()).device)
print(aae_model._get_name())
# Bug fix: AE_train_dataloader / AE_val_dataloader / AE_test_dataloader are
# never defined in this notebook (NameError); use the loaders from the data cell.
train_eval(aae_model, train_dataloader, val_dataloader, test_dataloader)

# for epoch in range(num_epochs):
#     # aae_encoder.train()
#     # aae_decoder.train()
#     # aae_discriminator.train()
#     aae_model.train()
#     if shuffle_files:
#         sys_rand = SystemRandom()
#         sys_rand.shuffle(AE_dataset.file_order_indices)
#     for sequences,_ in AE_dataloader:
#         sequences=sequences.squeeze().to(device)
#         # 1) reconstruction + generator loss
#         optimizer_G.zero_grad()
#         fake_z = aae_encoder(sequences)
#         decoded_seq = aae_decoder(fake_z)
#         G_loss = 0.001*adversarial_loss(aae_discriminator(fake_z),  torch.ones(sequences.size(0), 2,device=device)) + 0.999*reconstruction_loss(decoded_seq, sequences)
#         G_loss.backward()
#         optimizer_G.step()
#         # 2) discriminator loss
#         optimizer_D.zero_grad()
#         real_loss = adversarial_loss(aae_discriminator(torch.randn(sequences.size(0), 2, device=device)),  torch.ones(sequences.size(0), 2,device=device))
#         fake_loss = adversarial_loss(aae_discriminator(fake_z.detach()),  torch.zeros(sequences.size(0), 2,device=device))
#         D_loss = 0.5*(real_loss + fake_loss)
#         D_loss.backward()
#         optimizer_D.step()
#     # print loss
#     print(
#             "[Epoch %d/%d] [G loss: %f] [D loss: %f]"
#             % (epoch, num_epochs, G_loss.item(), D_loss.item())
#          )

Model device (Method 1): cuda:0
cuda:0
AdversarialAutoencoder

Train : time 253.4021 Epoch 0
Generator Loss: 0.8527 Discriminator Loss: 61.5562

--- Running Evaluation for Epoch 1 lr =0.01 wd 0.0001 ---
Computed Threshold: 0.0011
Val: Accuracy: 0.9964  
Test : Accuracy: 0.8894 Recall : 0.8443 FDR: 0.2954  F1-score: 0.7681  
Train : time 228.5464 Epoch 1
Generator Loss: 0.5067 Discriminator Loss: 60.1043

--- Running Evaluation for Epoch 2 lr =0.01 wd 0.0001 ---
Computed Threshold: 0.0011
Val: Accuracy: 0.9939  
Test : Accuracy: 0.8894 Recall : 0.8443 FDR: 0.2954  F1-score: 0.7682  

Train : time 210.8485 Epoch 0
Generator Loss: 1.1983 Discriminator Loss: 299.7112

--- Running Evaluation for Epoch 1 lr =0.01 wd 1e-05 ---
Computed Threshold: 0.0014
Val: Accuracy: 0.9957  
Test : Accuracy: 0.8894 Recall : 0.8443 FDR: 0.2954  F1-score: 0.7681  
Train : time 207.1711 Epoch 1
Generator Loss: 0.5891 Discriminator Loss: 57.8186

--- Running Evaluation for Epoch 2 lr =0.01 wd 1e-05 ---
Computed

In [None]:
# Adversarial autoencoder (76 input features by default), trained with SGD inside train_eval.
AAE_model = AdversarialAutoencoder()
# Grid: 3 learning rates x 1 weight decay, 6 epochs each, evaluated after every epoch.
train_eval(AAE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[1e-2,1e-3,1e-4],weight_decays=[1e-4])


In [None]:
# GRU-VAE
class GRUVAE(nn.Module):
    """
    Sequence variational autoencoder built from two stacked GRUs.

    The encoder GRU summarises a window of records into its last-layer hidden
    state, which is projected to the mean and log-variance of a Gaussian
    latent code. A sample of that code (projected back to ``hidden_dim``)
    initialises every layer of the decoder GRU. The decoder is then unrolled
    over an all-zero input of the same shape as the window, and a linear head
    followed by a sigmoid maps each step back to feature space.

    Works for any window length; defaults are num_layers=2, hidden_dim=256,
    latent_dim=32, dropout=0.01, input_dim=89 features per time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of both GRUs.
        latent_dim: Size of the latent code.
        num_layers: Number of stacked layers in each GRU.
        dropout: Dropout probability between stacked GRU layers.
    """
    def __init__(self, input_dim=89, hidden_dim=256, latent_dim=32, num_layers=2, dropout=0.01):
        super(GRUVAE, self).__init__()
        self.encoder_gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc_z_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.decoder_gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """
        Encode, sample and decode a batch of windows.

        Args:
            x: Tensor of shape [batch_size, seq_len, input_dim].

        Returns:
            (x_recon, mu, logvar): x_recon has the shape of ``x`` with values
            in (0, 1); mu and logvar have shape [batch_size, latent_dim].
        """
        _, hidden = self.encoder_gru(x)  # hidden: [num_layers, batch_size, hidden_dim]
        h = hidden[-1]  # last layer only: [batch_size, hidden_dim]
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)  # [batch_size, latent_dim]
        # The projected latent code becomes the initial hidden state of every
        # decoder layer (so the repeat count is the decoder's layer count).
        h0 = self.fc_z_to_hidden(z).unsqueeze(0).repeat(self.decoder_gru.num_layers, 1, 1)
        # The decoder is driven only by h0: its inputs are zeros shaped like x.
        decoder_input = torch.zeros_like(x)
        output, _ = self.decoder_gru(decoder_input, h0)  # [batch_size, seq_len, hidden_dim]
        # torch.sigmoid is the function; nn.Sigmoid is a module class and
        # nn.Sigmoid(tensor) raises TypeError instead of applying the sigmoid.
        x_recon = torch.sigmoid(self.fc_out(output))  # [batch_size, seq_len, input_dim]
        return x_recon, mu, logvar


In [None]:
# Windowed (window_size=10) data streams for the recurrent model, scaled with
# the network-wide min-max scalers returned by load_scalers("fitted_scalers").
loaded_scalers = load_scalers("fitted_scalers")


def _make_rnn_dataset(csv_files):
    """Build an unshuffled ModbusFlowStream of 10-step windows over ``csv_files``."""
    return ModbusFlowStream(
        shuffle=False, chunk_size=1, batch_size=64, csv_files=csv_files,
        scalers=loaded_scalers['network-wide']['min_max_scalers'], window_size=10,
    )


# The stream itself is given batch_size=64, so the DataLoader uses
# batch_size=1 (presumably the stream yields ready-made batches; the training
# loops squeeze away the extra leading dim).
RNN_train_dataset = _make_rnn_dataset(train_files)
RNN_train_dataloader = DataLoader(RNN_train_dataset, batch_size=1, shuffle=False)

RNN_val_dataset = _make_rnn_dataset(val_files)
RNN_val_dataloader = DataLoader(RNN_val_dataset, batch_size=1, shuffle=False)

RNN_test_dataset = _make_rnn_dataset(test_files)
RNN_test_dataloader = DataLoader(RNN_test_dataset, batch_size=1, shuffle=False)





Successfully loaded scalers for 'network-wide'


In [None]:
# Train/evaluate the GRU-VAE on the windowed RNN loaders; train_eval receives
# three learning rates and a single weight decay.
GRU_VAE_model = GRUVAE()
train_eval(
    GRU_VAE_model,
    RNN_train_dataloader,
    RNN_val_dataloader,
    RNN_test_dataloader,
    shuffle_files=False,
    num_epochs=6,
    eval_epoch=1,
    criterion_method="mse",
    learning_rates=[1e-2, 1e-3, 1e-4],
    weight_decays=[1e-4],
)



Train : time 219.29 s Epoch 1
Train Loss: 4.8276
--- Running Evaluation for Epoch 1 lr =5e-05 wd 1e-05 ---


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (64) must match the size of tensor b (9) at non-singleton dimension 0

In [None]:
# Quick 3-epoch standalone training run of the GRU-VAE on the windowed
# training stream built above (RNN_train_dataset / RNN_train_dataloader).
# NOTE(review): GRU_VAE_optimizer, vae_loss_function, device, shuffle_files,
# time and SystemRandom are not defined in this cell - presumably earlier ones.
for epoch in range(3):
    time_1 = time.time()
    train_loss = 0
    GRU_VAE_model.train()
    if shuffle_files:
        sys_rand = SystemRandom()
        sys_rand.shuffle(RNN_train_dataset.file_order_indices)
    for sequences, _ in RNN_train_dataloader:
        # Drop the size-1 dim added by DataLoader(batch_size=1).
        sequences = sequences.squeeze().to(device)
        GRU_VAE_optimizer.zero_grad()
        recon, mu, logvar = GRU_VAE_model(sequences)
        # Per-sample loss: vae_loss_function's value divided by the batch size.
        loss = vae_loss_function(recon, sequences, mu, logvar)/sequences.size(0)
        loss.backward()
        GRU_VAE_optimizer.step()
        train_loss += loss.item()
    print("time",time.time()-time_1,f"Epoch {epoch}, Train Loss: {train_loss/len(RNN_train_dataloader)}")


time 103.30588173866272 Epoch 0, Train Loss: 127700652.60516566
time 95.60702395439148 Epoch 1, Train Loss: 70232281.69612122
time 94.19640827178955 Epoch 2, Train Loss: 3457994.1813964844


### Federated learning

In [None]:
# ==============================================================================
# 1. IMPORT DEPENDENCIES
# ==============================================================================

import os
import time
from collections import Counter
from collections import OrderedDict
from random import SystemRandom
from typing import Dict, List, Tuple

import flwr as fl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, recall_score
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

# Let PyTorch fall back to the CPU for operators the Apple MPS backend does not
# implement (has no effect on CUDA/CPU-only machines).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Seed torch and numpy for reproducibility. The SystemRandom shuffles used
# during training draw from OS entropy and are NOT affected by these seeds.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    # Seed every visible GPU, and trade cuDNN autotuning for deterministic kernels.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")

# ==============================================================================
# 2. CONFIGURATION: TWEAK YOUR FEDERATED LEARNING EXPERIMENT
# ==============================================================================
class Config:
    """Settings for the federated-learning experiment.

    Values are class attributes and are read through the shared ``cfg``
    instance below.
    """

    # --- federation ---------------------------------------------------------
    NUM_CLIENTS = 5
    NUM_ROUNDS = 10
    STRATEGY = "FED_AVG"  # one of "FED_AVG", "FED_PROX"
    PROXIMAL_MU = 0.1  # FedProx proximal coefficient; only used when STRATEGY == "FED_PROX"

    # --- local training -----------------------------------------------------
    LOCAL_EPOCHS = 5
    # NOTE(review): load_data hard-codes batch_size=64 for its streams; this
    # value is not used there.
    BATCH_SIZE = 32
    lr = 0.001  # learning rate passed to the clients' Adam optimizers
    wd = 1e-4  # weight decay passed to the clients' Adam optimizers


# Shared configuration instance used by the rest of the notebook.
cfg = Config()


# ==============================================================================
# 3. DATA PREPARATION: PER-CLIENT CSV FILE LISTS AND DATA LOADERS
# ==============================================================================

def _benign_files_matching(tag):
    """Return the entries of modbus.dataset["benign_dataset_dir"] containing ``tag``, in original order."""
    return [path for path in modbus.dataset["benign_dataset_dir"] if tag in path]


# Benign CSV files for each device, selected by a substring of the file path.
ied1b_train_files = _benign_files_matching("ied1b")
ied1a_train_files = _benign_files_matching("ied1a")
ied4c_train_files = _benign_files_matching("ied4c")
scada_train_files = _benign_files_matching("scada-hmi")
cent_agent_train_files = _benign_files_matching("central-agent")

# Client id -> list of file groups (each group is a list of CSV paths).
# Clients 0-4 have a single (training) group; client 5 has two groups used
# for evaluation and testing.
# NOTE(review): both of client 5's groups are the central-agent *benign*
# training files, so its "test" set contains no attacks (and overlaps client
# 4's training data) - confirm which files were intended here.
CLIENT_DATA_MAPPING = {
    0: [ied1b_train_files],
    1: [ied1a_train_files],
    2: [ied4c_train_files],
    3: [scada_train_files],
    4: [cent_agent_train_files],
    5: [cent_agent_train_files, cent_agent_train_files],  # eval-test
}

# --- Data Loading Function for Clients ---
def load_data(client_id: int, model_name: str):
    """Build the DataLoader(s) for one federated client.

    ``CLIENT_DATA_MAPPING[client_id]`` is a list of file groups, each a list
    of CSV paths:

    * one group  -> training client: returns ``(train_loader, None)``;
    * two groups -> evaluation client: returns ``(val_loader, test_loader)``.

    Raises:
        ValueError: if the client has any other number of file groups.
    """
    file_groups = CLIENT_DATA_MAPPING[client_id]
    if len(file_groups) == 1:  # train
        return _make_client_loader(file_groups[0], model_name, train=True), None
    if len(file_groups) == 2:  # eval, test
        val_group, test_group = file_groups
        return (_make_client_loader(val_group, model_name, train=False),
                _make_client_loader(test_group, model_name, train=False))
    raise ValueError(
        f"client {client_id}: expected 1 (train) or 2 (val, test) file groups, "
        f"got {len(file_groups)}"
    )


def _make_client_loader(csv_files, model_name, train):
    """Wrap a ModbusFlowStream over ``csv_files`` in a batch_size=1 DataLoader.

    "GRUVAE" gets unshuffled 30-step windows. The other models ("AE", "VAE",
    "AdversarialAutoencoder") get single records (window_size=1), shuffled
    only for training so that evaluation runs are reproducible.
    """
    is_recurrent = model_name == "GRUVAE"
    dataset = ModbusFlowStream(
        shuffle=train and not is_recurrent,
        chunk_size=20,
        batch_size=64,
        csv_files=csv_files,
        scalers=loaded_scalers['network-wide']['min_max_scalers'],
        window_size=30 if is_recurrent else 1,
    )
    return DataLoader(dataset, batch_size=1, shuffle=False)


# ==============================================================================
# 4. FEDERATED LEARNING CLIENT: FlowerClient
# ==============================================================================
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates one reconstruction-based intrusion detector.

    The model type is dispatched on ``model._get_name()`` (its class name) and
    must be one of "AE", "VAE", "GRUVAE" or "AdversarialAutoencoder".

    Args:
        cid: Client id given by Flower.
        model: torch module to train/evaluate.
        trainloader: DataLoader used by ``fit`` (None for evaluation-only clients).
        val_loader: DataLoader whose reconstruction errors set the anomaly threshold.
        test_loader: DataLoader scored against that threshold in ``evaluate``.
        shuffle_files: If True, shuffle the training dataset's ``csv_files``
            list in place before every local epoch.
    """

    _SUPPORTED_MODELS = ("AE", "VAE", "GRUVAE", "AdversarialAutoencoder")
    _METRIC_NAMES = ("accuracy", "recall", "FDR", "f1_score")

    def __init__(self, cid, model, trainloader=None, val_loader=None, test_loader=None,
                 shuffle_files=False):
        self.cid = cid
        self.model = model
        self.train_dataloader = trainloader
        self.val_dataloader = val_loader
        self.test_dataloader = test_loader
        self.shuffle_files = shuffle_files

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays (in state_dict order) into the model."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    # ----------------------------------------------------------------- helpers
    @classmethod
    def _check_model(cls, model):
        """Return the model's class name, raising ValueError if it is unsupported."""
        name = model._get_name()
        if name not in cls._SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model type: {name}")
        return name

    @staticmethod
    def _proximal_term(params, anchors):
        """FedProx penalty (PROXIMAL_MU / 2) * ||w - w_global||^2 summed over ``params``.

        ``anchors`` maps id(parameter) -> detached copy of the global weight.
        When it is empty (FedAvg) the penalty is 0.0.
        """
        if not anchors:
            return 0.0
        return 0.5 * cfg.PROXIMAL_MU * sum(((p - anchors[id(p)]) ** 2).sum() for p in params)

    @staticmethod
    def _reconstruct(model, name, sequences):
        """Run the model and return only its reconstruction."""
        if name == "AE":
            return model(sequences)
        if name == "AdversarialAutoencoder":
            return model(sequences)[1]  # forward returns (latent, reconstruction)
        return model(sequences)[0]  # "VAE" / "GRUVAE": (reconstruction, mu, logvar)

    def _collect_scores(self, model, name, dataloader):
        """Return (per-sample summed squared reconstruction error, flattened labels) as NumPy arrays.

        NOTE(review): assumes each batch yields one label per sample (per window
        for GRUVAE) so that the two arrays line up - confirm for windowed streams.
        """
        scores, all_labels = [], []
        with torch.no_grad():
            for sequences, labels in dataloader:
                # Drop the size-1 dim added by DataLoader(batch_size=1).
                sequences = sequences.squeeze().to(DEVICE)
                recon = self._reconstruct(model, name, sequences)
                errors = nn.functional.mse_loss(recon, sequences, reduction="none")
                if errors.dim() == 1:
                    # squeeze() also removed the batch dim of a single-sample batch.
                    errors = errors.unsqueeze(0)
                # Sum over every non-batch dim, so windowed [B, T, F] input also
                # yields exactly one score per window.
                scores.append(errors.reshape(errors.size(0), -1).sum(dim=1).cpu().numpy())
                all_labels.append(labels.reshape(-1).cpu().numpy())
        if not scores:
            return np.array([]), np.array([])
        return np.concatenate(scores), np.concatenate(all_labels)

    # --------------------------------------------------------------- training
    def fit(self, parameters, config):
        """Train locally for cfg.LOCAL_EPOCHS epochs, starting from the received global weights.

        With cfg.STRATEGY == "FED_PROX" the FedProx proximal penalty is added
        to every optimised loss.

        Returns:
            (updated weights, len(train dataset), {}) as NumPyClient.fit expects.

        Raises:
            ValueError: if this client has no training loader or the model type is unsupported.
        """
        if self.train_dataloader is None:
            raise ValueError(f"client {self.cid} has no training data")
        self.set_parameters(parameters)
        model = self.model.to(DEVICE)
        name = self._check_model(model)
        # Copies of the global weights that the FedProx penalty pulls towards.
        anchors = ({id(p): p.detach().clone() for p in model.parameters()}
                   if cfg.STRATEGY == "FED_PROX" else {})
        criterion = nn.MSELoss(reduction="sum").to(DEVICE)
        is_aae = name == "AdversarialAutoencoder"
        if is_aae:
            adversarial_criterion = nn.BCELoss(reduction="sum")
            generator_params = list(model.encoder.parameters()) + list(model.decoder.parameters())
            discriminator_params = list(model.discriminator.parameters())
            optimizer_G = optim.Adam(generator_params, lr=cfg.lr, weight_decay=cfg.wd)
            optimizer_D = optim.Adam(discriminator_params, lr=cfg.lr, weight_decay=cfg.wd)
        else:
            model_params = list(model.parameters())
            AE_optimizer = optim.Adam(model_params, lr=cfg.lr, weight_decay=cfg.wd)
        print(f"\n==================client id={self.cid}  lr={cfg.lr}, wd={cfg.wd} ==================")
        # Weights are deliberately NOT re-initialised here: that would discard
        # the global model that set_parameters has just loaded.
        for epoch in range(cfg.LOCAL_EPOCHS):
            time_1 = time.time()
            model.train()
            train_loss = 0.0
            discriminator_loss = 0.0  # AAE only
            n_batches = 0
            if self.shuffle_files:
                SystemRandom().shuffle(self.train_dataloader.dataset.csv_files)
            for sequences, _ in self.train_dataloader:
                sequences = sequences.squeeze().to(DEVICE)
                n_batches += 1
                if is_aae:
                    batch = sequences.size(0)
                    target_ones = torch.ones(batch, 1, device=DEVICE, dtype=torch.float)
                    target_zeros = torch.zeros(batch, 1, device=DEVICE, dtype=torch.float)
                    # N(0, I) prior samples; latent size 2 is assumed to match
                    # the encoder output - TODO confirm against the AAE definition.
                    random_latent = torch.randn(batch, 2, device=DEVICE)
                    # 1) generator (encoder + decoder): fool the discriminator and reconstruct.
                    optimizer_G.zero_grad()
                    fake_z, decoded_seq = model(sequences)
                    G_loss = (0.001 * adversarial_criterion(model.discriminator(fake_z), target_ones)
                              + 0.999 * criterion(decoded_seq, sequences))
                    (G_loss + self._proximal_term(generator_params, anchors)).backward()
                    optimizer_G.step()
                    # 2) discriminator: prior samples -> 1, encoded samples -> 0.
                    optimizer_D.zero_grad()
                    real_loss = adversarial_criterion(model.discriminator(random_latent), target_ones)
                    fake_loss = adversarial_criterion(model.discriminator(fake_z.detach()), target_zeros)
                    D_loss = 0.001 * 0.5 * (real_loss + fake_loss)
                    (D_loss + self._proximal_term(discriminator_params, anchors)).backward()
                    optimizer_D.step()
                    train_loss += G_loss.item()
                    discriminator_loss += D_loss.item()
                else:
                    AE_optimizer.zero_grad()
                    if name == "AE":
                        loss = criterion(model(sequences), sequences) / sequences.size(0)
                    else:  # "VAE" / "GRUVAE"
                        recon, mu, logvar = model(sequences)
                        loss = vae_loss_function(recon, sequences, mu, logvar) / sequences.size(0)
                    (loss + self._proximal_term(model_params, anchors)).backward()
                    AE_optimizer.step()
                    train_loss += loss.item()
            n_batches = max(n_batches, 1)  # avoid dividing by zero on an empty stream
            print(f"Train : time {(time.time()-time_1):.2f} s",
                  f"Epoch {epoch+1}")
            if is_aae:
                print(f"Generator Loss: {train_loss / n_batches:.4f}",
                      f"Discriminator Loss: {discriminator_loss / n_batches:.4f}")
            else:
                print(f"client id {self.cid} Train Loss: {train_loss / n_batches:.4f}")
        return self.get_parameters(config={}), len(self.train_dataloader.dataset), {}

    # ------------------------------------------------------------- evaluation
    def evaluate(self, parameters, config):
        """Evaluate the received global model on this client's validation/test streams.

        The anomaly threshold is ``compute_threshold(val_errors, k=1)``. A test
        sample is flagged as an intrusion when its reconstruction error exceeds
        the threshold, and any non-zero label counts as an intrusion.

        Returns:
            (mean test reconstruction error, len(test dataset), metrics) where
            metrics holds "accuracy", "recall", "FDR" and "f1_score". Clients
            without validation/test loaders return (0.0, 0, all-zero metrics),
            so they carry zero weight in example-weighted aggregation.
        """
        if self.val_dataloader is None or self.test_dataloader is None:
            return 0.0, 0, {key: 0.0 for key in self._METRIC_NAMES}
        self.set_parameters(parameters)
        model = self.model.to(DEVICE)
        name = self._check_model(model)
        model.eval()
        print(f"--- Running Evaluation on client {self.cid} lr ={cfg.lr} wd {cfg.wd} ---")

        val_scores, val_labels = self._collect_scores(model, name, self.val_dataloader)
        threshold = compute_threshold(list(val_scores), k=1)
        val_predictions = (val_scores > threshold).astype(int)
        # Compare binary predictions with binary labels (non-zero label = intrusion).
        val_accuracy = accuracy_score((val_labels != 0).astype(int), val_predictions)
        print(f"Val: Accuracy: {val_accuracy:.4f}  ")

        test_scores, test_labels = self._collect_scores(model, name, self.test_dataloader)
        predictions = (test_scores > threshold).astype(int)
        binary_test_labels = (test_labels != 0).astype(int)
        misclassified_original_labels = test_labels[binary_test_labels != predictions]
        print(f"---thr {threshold:.4}")
        print(f"Counts of  labels: {dict(sorted(Counter(test_labels).items()))}")
        print(f"Counts of misclassified original labels: {dict(sorted(Counter(misclassified_original_labels).items()))}")

        accuracy = accuracy_score(binary_test_labels, predictions)
        f1 = f1_score(binary_test_labels, predictions, zero_division=0)
        recall = recall_score(binary_test_labels, predictions, zero_division=0)
        _, fp, _, tp = confusion_matrix(binary_test_labels, predictions, labels=[0, 1]).ravel()
        # FDR = FP / (FP + TP); defined as 0 when nothing is flagged.
        fdr = fp / (fp + tp) if (fp + tp) else 0.0
        print(f"Test : Accuracy: {accuracy:.4f} Recall : {recall:.4f} FDR: {fdr:.4f}  F1-score: {f1:.4f} ")
        metrics = {
            "accuracy": float(accuracy),
            "recall": float(recall),
            "FDR": float(fdr),
            "f1_score": float(f1),
        }
        return float(np.mean(test_scores)), len(self.test_dataloader.dataset), metrics


# ==============================================================================
# 6. SERVER-SIDE LOGIC AND SIMULATION START
# ==============================================================================
def client_fn(cid: str) -> FlowerClient:
    """Build the Flower client for client id ``cid`` (a numeric string).

    Expects ``load_data`` to return a pair of loaders; the second one is passed
    as FlowerClient's fourth positional argument (``val_loader``).

    NOTE(review): FlowerClient dispatches on ``model._get_name()`` (the class
    name) - confirm that AutoEncoder's class name is one it handles.
    """
    client_index = int(cid)
    model = AutoEncoder().to(DEVICE)
    train_loader, second_loader = load_data(client_index, model._get_name())
    client = FlowerClient(cid, model, train_loader, second_loader)
    return client.to_client()

def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate client evaluation metrics, weighting each client by its number of examples.

    "accuracy" and "f1_score" are always aggregated, and every client must
    report them. Any other numeric metric reported by *all* clients (e.g.
    "recall", "FDR") is aggregated as well; non-numeric values are skipped.

    When the total number of examples is zero (including an empty ``metrics``
    list), every aggregated value is 0.0 instead of raising ZeroDivisionError.
    """
    from numbers import Real

    keys = {"accuracy", "f1_score"}
    if metrics:
        shared = set.intersection(*(set(m) for _, m in metrics))
        keys |= {k for k in shared if all(isinstance(m[k], Real) for _, m in metrics)}
    total = sum(num_examples for num_examples, _ in metrics)
    if total == 0:
        return {key: 0.0 for key in keys}
    return {
        key: sum(num_examples * m[key] for num_examples, m in metrics) / total
        for key in keys
    }