### Download and make the dataset ready in Kaggle 


In [4]:

# ## uncomment if The zip file of the dataset isn't downloaded,extraced 
# !pip install gdown
# Copy the link. The file ID is the long string of characters between d/ and /view.
#For example, in the URL https://drive.google.com/file/d/1aBcDeFgHiJkLmNoPqRsTuVwXyZ/view?usp=sharing, 
#the file ID is 1aBcDeFgHiJkLmNoPqRsTuVwXyZ
# !mkdir /kaggle/tmp
# !gdown  1pzXpA5Cz0DJmjRsLxlqRNnJq-kOUvojb -O /kaggle/tmp/Labeled_CICMODBUS2023.zip
# !unzip /kaggle/tmp/Labeled_CICMODBUS2023.zip -d /kaggle/working/

# # ## uncomment if the python modules (modbus.py,utils.py ,...) not cloned  and added to the path 

# !git clone https://github.com/hamid-rd/FLBased-ICS-NIDS.git
# import sys
# # Add the repository folder to the Python path
# repo_path = '/kaggle/working/FLBased-ICS-NIDS'
# repo_input_path = '/kaggle/input/training/FLBased-ICS-NIDS'
# dataset_path = '/kaggle/input/training/'

# for path in {repo_path,repo_input_path,dataset_path}:
#     if path not in sys.path:
#         sys.path.append(path)


In [1]:
# To test if every thing is okay (modbus.py class and correct number of founded csv files )
from modbus import ModbusDataset,ModbusFlowStream

# dataset_directory = "/kaggle/working/ModbusDataset" 
# dataset_directory = "/kaggle/input/training/ModbusDataset" 
dataset_directory = "dataset" 

modbus = ModbusDataset(dataset_directory,"ready")
modbus.summary_print()

# Don't forget to save version in kaggle (to save outputs written on the disk (/kaggle/working/))  

 The CIC Modbus Dataset contains network (pcap) captures and attack logs from a simulated substation network.
                The dataset is categorized into two groups: an attack dataset and a benign dataset
                The attack dataset includes network traffic captures that simulate various types of Modbus protocol attacks in a substation environment.
                The attacks are reconnaissance, query flooding, loading payloads, delay response, modify length parameters, false data injection, stacking Modbus frames, brute force write and baseline replay.
                These attacks are based of some techniques in the MITRE ICS ATT&CK framework.
                On the other hand, the benign dataset consists of normal network traffic captures representing legitimate Modbus communication within the substation network.
                The purpose of this dataset is to facilitate research, analysis, and development of intrusion detection systems, anomaly detection algorithms and

### Unsupervised Autoencoder training  

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np # For standard deviation calculation
from modbus import ModbusDataset,ModbusFlowStream
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix,recall_score
import torch.optim as optim
from torch.utils.data import DataLoader
import time
from utils import load_scalers
from random import SystemRandom
from sklearn.model_selection import train_test_split
import itertools
import torch.nn.init as init


def compute_threshold(mse_values,k=1):

    """
    K-SIGMA
    Computes the anomaly detection threshold (for marking sample as Intrusion if the IS was greater )
    based on the mean and standard deviation of Mean Squared Error (MSE) values.
    Formula: thr = mean(MSE) + std(MSE)
    Args:
    mse_values (torch.Tensor or list/np.array): A tensor or list of MSE values

                            obtained from the validation set.
    Returns:
    float: The calculated threshold.
    float: The calculated std.

    """
    if not isinstance(mse_values, torch.Tensor):
        mse_values = torch.tensor(mse_values, dtype=torch.float32)
    if mse_values.numel() == 0:
        return 0.0
    mean_mse = torch.mean(mse_values)
    std_mse = torch.std(mse_values)
    print("-----------mse_loss mean : ",f"{mean_mse.item():.4f}","std:",f"{std_mse.item():.4f}")
    threshold = mean_mse + k*std_mse
    return threshold.item(),std_mse.item()

def vae_loss_function(recon_x, x, mu, logvar,beta =1):
    """
    VAE loss function.
    """
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (BCE + beta*KLD)

def _init_weights( module):
    ## for one layer apply Xavier Initialization
    if isinstance(module, nn.Linear):
        init.xavier_normal_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)
    return module


In [3]:
# dataset_directory = "/kaggle/input/training/ModbusDataset" # change this to the folder contain benign and attack subdirs
dataset_directory = "dataset" 
modbus = ModbusDataset(dataset_directory,"ready")
modbus.summary_print()

 The CIC Modbus Dataset contains network (pcap) captures and attack logs from a simulated substation network.
                The dataset is categorized into two groups: an attack dataset and a benign dataset
                The attack dataset includes network traffic captures that simulate various types of Modbus protocol attacks in a substation environment.
                The attacks are reconnaissance, query flooding, loading payloads, delay response, modify length parameters, false data injection, stacking Modbus frames, brute force write and baseline replay.
                These attacks are based of some techniques in the MITRE ICS ATT&CK framework.
                On the other hand, the benign dataset consists of normal network traffic captures representing legitimate Modbus communication within the substation network.
                The purpose of this dataset is to facilitate research, analysis, and development of intrusion detection systems, anomaly detection algorithms and

In [4]:

# AutoEncoder (AE)
class AE(nn.Module):
    """
    Encoder: (89-64-32)
    Decoder: (32-64-89)
    """
    def __init__(self,input_dim=89):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon


# Variational AutoEncoder (VAE)
class VAE(nn.Module):
    """
    Encoder: (89-64-64-32 for mu and log_var)
    Decoder: (32-64-64-89)
    return x_recon, mu, logvar
    """
    def __init__(self,input_dim=89):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(64, 32)
        self.fc_logvar = nn.Linear(64, 32)
        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
            nn.Sigmoid()
                    )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    
class AAE_Encoder(nn.Module):
    def __init__(self,input_dim=76):
        """
        Encoder(Generator):(89-16-4-2)
        """
        super(AAE_Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 4),
            nn.LeakyReLU(0.2),
            nn.Linear(4, 2))
    def forward(self, x):
        return self.encoder(x)
class AAE_Decoder(nn.Module):
    def __init__(self,input_dim=76):
        super(AAE_Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(2, 4),
            nn.LeakyReLU(),
            nn.Linear(4, 16),
            nn.LeakyReLU(),
            nn.Linear(16, input_dim),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.decoder(x)
class AAE_Discriminator(nn.Module):
    def __init__(self):
        super(AAE_Discriminator, self).__init__()
        # corrected to 2-16-4-1
        self.discriminator = nn.Sequential(
            nn.Linear(2, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 4),
            nn.LeakyReLU(),
            nn.Linear(4, 1), 
            nn.Sigmoid()
        )    
    def forward(self, x):
        return self.discriminator(x)
 
class AdversarialAutoencoder(nn.Module):
    def __init__(self):
        super(AdversarialAutoencoder, self).__init__()
        self.encoder = AAE_Encoder()
        self.decoder = AAE_Decoder()
        self.discriminator = AAE_Discriminator()
    def forward(self, x):
        fake_z = self.encoder(x)
        x_recon = self.decoder(fake_z)
        return fake_z,x_recon


In [8]:
from collections import Counter

def train_eval(model,train_dataloader,val_dataloader,test_dataloader,learning_rates= [5e-6,1e-7,5e-5,1e-5,1e-6],
               weight_decays=[1e-5,1e-4,1e-7],shuffle_files=True,num_epochs=20,eval_epoch=4,criterion_method="mse", k_range=[1,3],train_model=True):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model=model.to(device)
    if criterion_method=="bce":
        criterion = nn.BCELoss(reduction='sum').to(device)
        eval_criterion = nn.BCELoss(reduction='none').to(device)
    else: #mse
        criterion = nn.MSELoss(reduction='sum').to(device)
        eval_criterion = nn.MSELoss(reduction='none').to(device)
    best_f1=0 #to save best version of the model during test
    best_recall=0 #to save best version of the model during test

    for lr, wd in itertools.product(learning_rates, weight_decays):
        if model._get_name()=="AdversarialAutoencoder":
            adversarial_criterion= nn.BCELoss(reduction="sum")
            optimizer_D = optim.SGD(model.discriminator.parameters(), lr=lr, weight_decay=wd)
            optimizer_G =  optim.SGD(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=lr, weight_decay=wd)
        else:
            AE_optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
            ### new code
            # AE_optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=wd)

        print(f"\n==================  lr={lr}, wd={wd} ==================")
        if train_model==True:
            model.apply(_init_weights)
        for epoch in range(num_epochs):
            if train_model==True:
                time_1 = time.time()
                model.train()
                train_loss = 0
                ## for AAE
                Discriminator_loss = 0
                if shuffle_files:
                    sys_rand = SystemRandom()
                    sys_rand.shuffle(train_dataloader.dataset.csv_files)
                for sequences, labels in train_dataloader:
                    sequences=sequences.squeeze().to(device)
                    if labels.sum()!=0:
                        continue
                    if model._get_name()=="AdversarialAutoencoder":
                        target_ones= torch.ones(sequences.size(0), 1,device=device,dtype=torch.float)
                        target_zeros= torch.zeros(sequences.size(0), 1,device=device,dtype=torch.float)
                        random_latent = torch.randn(sequences.size(0), 2, device=device)
                        optimizer_G.zero_grad()
                        fake_z,decoded_seq = model(sequences)
                        G_loss = 0.001*adversarial_criterion(model.discriminator(fake_z),target_ones ) + 0.999*criterion(decoded_seq, sequences)
                        G_loss.backward()
                        optimizer_G.step()
                        # 2) discriminator loss
                        optimizer_D.zero_grad()
                        real_loss = adversarial_criterion(model.discriminator(random_latent), target_ones)
                        fake_loss = adversarial_criterion(model.discriminator(fake_z.detach()),  target_zeros)
                        D_loss = 0.5*(real_loss + fake_loss)
                        D_loss.backward()
                        optimizer_D.step()
                        train_loss+=G_loss.item()
                        Discriminator_loss+=D_loss.item()   
                    else:
                        AE_optimizer.zero_grad()
                        if model._get_name()=="AE":
                            recon = model(sequences)
                            loss = criterion(recon, sequences) / sequences.size(0)
                        elif model._get_name()=="VAE" or model._get_name()=="GRUVAE":
                            recon, mu, logvar = model(sequences)
                            loss = vae_loss_function(recon, sequences, mu, logvar) /sequences.size(0)
                        loss.backward()
                        AE_optimizer.step()
                        train_loss += loss.item()
                print(f"Train : time {(time.time()-time_1):.2f} s",
                f"Epoch {epoch+1}")
                if model._get_name()=="AdversarialAutoencoder":
                    print(f"Generator Loss: {train_loss / len(train_dataloader):.4f}",
                        f"Discriminator Loss: {Discriminator_loss / len(train_dataloader):.4f}")
                else:
                    print(f"Train Loss: {train_loss / len(train_dataloader):.4f}")
            # Evaluate part
            if (epoch + 1) % eval_epoch == 0:
                model.eval() 
                all_val_losses = []
                all_val_labels = []
                print(f"--- Running Evaluation for Epoch {epoch+1} lr ={lr} wd {wd} ---")
                with torch.no_grad():
                    for sequences, labels in val_dataloader:
                        sequences = sequences.squeeze().to(device) 
                        if labels.sum()!=0:
                            continue
                        if criterion_method=="bce":
                            ## may test features be greater than 1 after scaling 
                            sequences=torch.clamp(sequences, min=0.0, max=1.0)      
                        if model._get_name()=="AE":
                            recon = model(sequences)
                        elif model._get_name()=="VAE" or model._get_name()=="GRUVAE" :
                            recon, _, _ = model(sequences)
                        elif model._get_name()=="AdversarialAutoencoder":
                            _,recon= model(sequences)
                        val_loss = eval_criterion(recon, sequences)
                        if val_loss.dim() > 1:
                            val_loss = val_loss
                        else:
                            val_loss = val_loss.unsqueeze(dim=0)
                            labels = labels.unsqueeze(dim=0)
                        if val_loss.dim()==3:
                            ##GRU : mean of window
                            val_loss = val_loss.mean(dim=1)
                        val_loss = val_loss.sum(dim=1)
                        all_val_losses.extend(val_loss.cpu().numpy())
                        all_val_labels.extend(labels.flatten().cpu().numpy())     
                threshold_1,std_mse = compute_threshold(all_val_losses,k=0)

                all_val_losses = np.array(all_val_losses).squeeze()  
                all_val_labels = np.array(all_val_labels).squeeze()  
                # If intrusion score > threshold, predict 1 (intrusion), else 0 (benign)
                # For FDR, get True Positives (TP) and False Positives (FP)
                
                predictions = (all_val_losses > threshold_1).astype(int)

                accuracy = accuracy_score(all_val_labels, predictions)
                print(f"Val: Accuracy: {accuracy:.4f}  ")
                model.eval() 
                all_test_losses = []
                all_test_labels = []
                with torch.no_grad():
                    for sequences, labels in test_dataloader:
                        sequences = sequences.squeeze().to(device)
                        labels = labels.squeeze().to(device)
                        if criterion_method=="bce":
                            ## may test features be greater than 1 after scaling 
                            sequences=torch.clamp(sequences, min=0.0, max=1.0)
                        if model._get_name()=="AE":
                            recon = model(sequences)
                        elif model._get_name()=="VAE"  or model._get_name()=="GRUVAE":
                            recon, mu, logvar = model(sequences)
                        elif model._get_name()=="AdversarialAutoencoder":
                            _,recon= model(sequences)

                        intrusion_scores = eval_criterion(recon, sequences)
                        if intrusion_scores.dim() > 1:
                            intrusion_scores = intrusion_scores
                        else:
                            intrusion_scores = intrusion_scores.unsqueeze(dim=0)
                            labels = labels.unsqueeze(dim=0)
                        if intrusion_scores.dim()==3:
                            ##GRU : mean of window
                            intrusion_scores = intrusion_scores.mean(dim=1)
                        intrusion_scores = intrusion_scores.sum(dim=1)
                        all_test_losses.extend(intrusion_scores.cpu().numpy())
                        all_test_labels.extend(labels.cpu().numpy())

                all_test_losses = np.array(all_test_losses)
                all_test_labels = np.array(all_test_labels)
                temp_best_recall =best_recall
                temp_best_f1 =best_f1

                for k in k_range:
                    threshold=threshold_1+k*std_mse
                    print(f" K: {k} K-SIGMA Threshold : ---thr {threshold:.4}")
                    predictions = (all_test_losses > threshold).astype(int)
                    binary_test_labels = (all_test_labels != 0).astype(int)

                    # Find the indices where the prediction was incorrect
                    misclassified_indices = np.where(binary_test_labels != predictions)[0]

                    # Get the original labels for those misclassified instances
                    misclassified_original_labels = all_test_labels[misclassified_indices]

                    # To get a summary count of which labels were misclassified
                    print(Counter(predictions),Counter(binary_test_labels))
                    print(f"Counts of  labels: {dict(sorted(Counter(all_test_labels).items()))}")
                    print(f"Counts of misclassified original labels: {dict(sorted(Counter(misclassified_original_labels).items()))}")
                    accuracy = accuracy_score(binary_test_labels, predictions)
                    f1 = f1_score(binary_test_labels, predictions, zero_division=0)
                    recall = recall_score(binary_test_labels, predictions,zero_division=0)
                    _, fp, _, tp = confusion_matrix(binary_test_labels, predictions, labels=[0, 1]).ravel()
                    # FDR = FP / (FP + TP) 
                    if (fp + tp) == 0:
                        fdr = 0.0 
                    else:
                        fdr = fp / (fp + tp)
                    print(f"Test : Accuracy: {accuracy:.4f} Recall : {recall:.4f} FDR: {fdr:.4f}  F1-score: {f1:.4f}  ")
                    !mkdir best_models -p
                    if f1>best_f1 :
                        best_f1=f1
                    if recall>best_recall:
                        best_recall=recall
                if (best_recall>temp_best_recall or best_f1 > temp_best_f1):
                    if train_model==True:
                        save_path ="best_models/"+model._get_name()+"_f1_"+f"{best_f1:.2f}" +"_recall_"+f"{best_recall:.2f}" +"_.pth"
                        torch.save(model.state_dict(),save_path)
                        print("model",model._get_name(),"is saved in" ,save_path )


#### Centralized : external scenario -> ied1a node

In [9]:
# train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("network-wide")!=-1]
train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("network-wide")!=-1][:]
val_files = [col for col in modbus.dataset["benign_dataset_dir"] if col.find("ied1a")!=-1][:]
test_files=[col for col in modbus.dataset["attack_dataset_dir"]["external"] if col.find("ied1a")!=-1]
sys_rand = SystemRandom()

sys_rand.shuffle(train_files)
sys_rand.shuffle(val_files)
sys_rand.shuffle(test_files)


print("ied1b comp ied attack ->\n test: ",len(test_files),test_files)
print("----------- Network-wide number of csv files -> \n ----------- train :",len(train_files),train_files,"\n -------- valid:",len(val_files),val_files)

ied1b comp ied attack ->
 test:  1 ['dataset/ModbusDataset/attack/external/ied1a/ied1a-network-capture/ready/veth4edc015-0-labeled.csv']
----------- Network-wide number of csv files -> 
 ----------- train : 19 ['dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-21-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-16-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-27-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-18-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-25-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-28-labeled.csv', 'dataset/ModbusDataset/benign/network-wide-pcap-capture/network-wide/ready/network-wide-normal-31-labeled.csv', 'data

In [10]:
### Try The Copy-on-Write (CoW) technique, share the same single copy of the dataset in memory with multiple forked workers from the main process
# Ensure to have enough memory for saving large tensors in the ram 
###### else use chunk_size =1 and read the files iteratively

use_cow=False
window_size=1
loaded_scalers=load_scalers('fitted_scalers')

Successfully loaded scalers for 'network-wide'


In [7]:


# This cell Initializes and returns train, validation, and test dataloaders.

# This function supports two strategies for data loading:
# 1. Copy-on-Write (use_cow=True): Loads the entire dataset into RAM. This is fast
#     but memory-intensive. It allows multiple worker processes to share the same
#     dataset copy in memory, which is efficient for multiprocessing.
# 2. Iterative (use_cow=False): Reads data from files in small chunks. This is
#     slower but uses significantly less memory, suitable for very large datasets
#     that don't fit in RAM.

#     train_files (list): List of file paths for the training dataset.
#     val_files (list): List of file paths for the validation dataset.
#     test_files (list): List of file paths for the test dataset.
#     window_size (int): The size of the sliding window for sequence data.
#     use_cow (bool, optional): If True, uses the Copy-on-Write strategy. 
#                                 Defaults to True.

#      return        (train_dataloader, val_dataloader, test_dataloader)

if use_cow==True:
    large_chunk_size = modbus.dataset["metadata"]["founded_files_num"]["total_dataset_num"]

    dataset_configs = {
        "train": {"files": train_files},
        "val": {"files": val_files},
        "test": {"files": test_files},
    }
    datasets = {}
    ae_datasets = {}

    print("Cow Processing datasets...")

    for name, config in dataset_configs.items():
        print(f"  - Creating '{name}' dataset...")
        
        # 1. Create the primary ModbusFlowStream dataset
        datasets[name] = ModbusFlowStream(
            shuffle=False,
            chunk_size=large_chunk_size,
            batch_size=1,
            csv_files=config["files"],
            scalers=loaded_scalers['network-wide']['min_max_scalers'],
            window_size=window_size
        )
        
        # 2. Call __getitem__(0) once to load the entire dataset chunk into memory
        datasets[name].__getitem__(0)
        
        # used for specific AE training/evaluation without re-reading files.
        ae_datasets[name] = ModbusFlowStream(
            shuffle=False,  # AE data is typically processed in order
            chunk_size=large_chunk_size,
            batch_size=1,
            csv_files=[],  # No CSV files needed as we copy the data directly
            scalers=None,   # Data is already scaled from the original dataset
            window_size=window_size
        )
        
        # 4. Manually copy the loaded data and properties to the AE dataset

        ae_datasets[name].current_chunk_data =  datasets[name].current_chunk_data
        ae_datasets[name].current_len_chunk_data =  datasets[name].current_len_chunk_data
        ae_datasets[name].current_chunk_labels =  datasets[name].current_chunk_labels
        ae_datasets[name].total_batches =  datasets[name].total_batches
        
        print(f"  - Finished '{name}' dataset.")
    train_dataloader=DataLoader(ae_datasets['train'],batch_size=64,shuffle=True,num_workers=4,persistent_workers=True,prefetch_factor=2,pin_memory=True)
    val_dataloader=DataLoader(ae_datasets['val'],batch_size=64,shuffle=False,num_workers=4,persistent_workers=True,prefetch_factor=2,pin_memory=True)
    test_dataloader=DataLoader(ae_datasets['test'],batch_size=64,shuffle=False,num_workers=4,persistent_workers=True,prefetch_factor=2,pin_memory=True)

else :
    train_dataloader=DataLoader(ModbusFlowStream( 
        shuffle=True,chunk_size=1,batch_size=64,csv_files=train_files,scalers=loaded_scalers['network-wide']['min_max_scalers'],window_size=window_size
    ),  batch_size=1,shuffle=False)
    val_dataloader=DataLoader(ModbusFlowStream( 
        shuffle=False,chunk_size=1,batch_size=64,csv_files=val_files,scalers=loaded_scalers['network-wide']['min_max_scalers'],window_size=window_size
    ),batch_size=1,shuffle=False)
    test_dataloader=DataLoader(ModbusFlowStream(shuffle=False,chunk_size=1,batch_size=64,csv_files=test_files,scalers=loaded_scalers['network-wide']['min_max_scalers'],window_size=window_size),
                               batch_size=1,shuffle=False)


In [8]:
print(len(train_dataloader),len(val_dataloader),len(test_dataloader))

44111 19158 1960


In [11]:
# train_eval(AE_model,AE_train_dataloader,AE_val_dataloader,AE_test_dataloader,shuffle_files=True,num_epochs=20,eval_epoch=1,criterion_method="bce",learning_rates=[5e-4],weight_decays=[0])
AE_model = AE(input_dim=76)
train_eval(AE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[5e-5,1e-4,1e-5,1e-6],weight_decays=[1e-4])



Train : time 163.55 s Epoch 1
Train Loss: 0.2204
--- Running Evaluation for Epoch 1 lr =5e-05 wd 0.0001 ---
-----------mse_loss mean :  0.0031 std: 0.0626
Val: Accuracy: 0.9751  
 K: 1 K-SIGMA Threshold : ---thr 0.06569
Counter({1: 65686, 0: 59746}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 29808, 1: 137}
Test : Accuracy: 0.7613 Recall : 0.9962 FDR: 0.4538  F1-score: 0.7056  
 K: 3 K-SIGMA Threshold : ---thr 0.1908
Counter({1: 65565, 0: 59867}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 29715, 1: 165}
Test : Accuracy: 0.7618 Recall : 0.9954 FDR: 0.4532  F1-score: 0.7058  
model AE is saved in best_models/AE_f1_0.71_recall_1.00_.pth
Train : time 174.78 s Epoch 2
Train Loss: 0.0026
--- Running Evaluation for Epoch 2 lr =5e-05 wd 0.0001 ---
-----------mse_loss mean :  0.00

In [11]:
VAE_model = VAE(input_dim=76)
train_eval(VAE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[1e-2,1e-3,1e-4],weight_decays=[1e-4])



Train : time 214.81 s Epoch 1
Train Loss: 1.3880
--- Running Evaluation for Epoch 1 lr =0.01 wd 0.0001 ---
-----------mse_loss mean :  0.5267 std: 1.3753
Val: Accuracy: 0.8128  
 K: 1 K-SIGMA Threshold : ---thr 1.902
Counter({0: 94388, 1: 31044}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 4751, 1: 9713, 5: 1, 6: 8}
Test : Accuracy: 0.8846 Recall : 0.7301 FDR: 0.1530  F1-score: 0.7842  
 K: 3 K-SIGMA Threshold : ---thr 4.653
Counter({0: 119349, 1: 6083}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 2430, 1: 32327, 2: 1, 3: 1, 4: 1, 5: 2, 6: 29, 7: 1}
Test : Accuracy: 0.7226 Recall : 0.1014 FDR: 0.3995  F1-score: 0.1735  
model VAE is saved in best_models/VAE_f1_0.78_recall_0.73_.pth
Train : time 213.60 s Epoch 2
Train Loss: 1.3741
--- Running Evaluation for Epoch 2 lr =0.01

In [13]:
VAE_model = VAE(input_dim=76)
train_eval(VAE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[1e-2,1e-3,1e-4],weight_decays=[1e-5],k_range=[0,1,3])



Train : time 219.01 s Epoch 1
Train Loss: 1.5069
--- Running Evaluation for Epoch 1 lr =0.01 wd 1e-05 ---
-----------mse_loss mean :  1.1421 std: 2.1842
Val: Accuracy: 0.8135  
 K: 0 K-SIGMA Threshold : ---thr 1.142
Counter({0: 87765, 1: 37667}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 11401, 1: 9740, 5: 1, 6: 8}
Test : Accuracy: 0.8314 Recall : 0.7293 FDR: 0.3027  F1-score: 0.7130  
 K: 1 K-SIGMA Threshold : ---thr 3.326
Counter({0: 113934, 1: 11498}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 8073, 1: 32555, 2: 1, 3: 1, 4: 1, 5: 2, 6: 29, 7: 1}
Test : Accuracy: 0.6758 Recall : 0.0951 FDR: 0.7021  F1-score: 0.1442  
 K: 3 K-SIGMA Threshold : ---thr 7.695
Counter({0: 120569, 1: 4863}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 

In [15]:
AAE_model = AdversarialAutoencoder()
train_eval(AAE_model,train_dataloader,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=6,eval_epoch=1,criterion_method="mse",learning_rates=[1e-2,1e-3,1e-4],weight_decays=[1e-4])



Train : time 267.13 s Epoch 1
Generator Loss: 491.9236 Discriminator Loss: 0.3799
--- Running Evaluation for Epoch 1 lr =0.01 wd 0.0001 ---
-----------mse_loss mean :  6.6088 std: 3.8604
Val: Accuracy: 0.3866  
 K: 1 K-SIGMA Threshold : ---thr 10.47
Counter({0: 117814, 1: 7618}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 4729, 1: 33091, 2: 1, 3: 1, 4: 1, 5: 2, 6: 29, 7: 1}
Test : Accuracy: 0.6982 Recall : 0.0802 FDR: 0.6208  F1-score: 0.1324  
 K: 3 K-SIGMA Threshold : ---thr 18.19
Counter({0: 125314, 1: 118}) Counter({0: 89417, 1: 36015})
Counts of  labels: {0: 89417, 1: 35978, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Counts of misclassified original labels: {0: 1, 1: 35861, 2: 1, 3: 1, 4: 1, 5: 2, 6: 31, 7: 1}
Test : Accuracy: 0.7138 Recall : 0.0032 FDR: 0.0085  F1-score: 0.0065  
model AdversarialAutoencoder is saved in best_models/AdversarialAutoencoder_f1_0.01_recall_0.08_.pth

#### Trained model evaluation on the compromised-ied and compromised scada scenarios 

No exact labeling for the comp ied scenario results in low performance 

In [57]:
Trained_AE_model=AE(input_dim=76)
Trained_AE_model.load_state_dict(torch.load("./best_models/AE_f1_0.84_recall_1.00_.pth"))
Trained_VAE_model=VAE(input_dim=76)
Trained_VAE_model.load_state_dict(torch.load("./best_models/VAE_f1_0.79_recall_0.81_.pth"))
Trained_AAE_model=AdversarialAutoencoder()
Trained_AAE_model.load_state_dict(torch.load("./best_models/AdversarialAutoencoder_f1_0.87_recall_1.00_.pth"))

<All keys matched successfully>

In [63]:
for scenario in {"compromised-scada","compromised-ied"}:
    print("scenario :",scenario,"node ied1b")
    val_files = [col for col in modbus.dataset["benign_dataset_dir"] if col.find("ied1b")!=-1][:]
    test_files= [col for col in modbus.dataset["attack_dataset_dir"][scenario] if col.find("ied1b")!=-1]
    #
    sys_rand = SystemRandom()
    sys_rand.shuffle(val_files)
    sys_rand.shuffle(test_files)
    print("----------- benign valid files:",len(val_files),val_files)
    print(f"----------{scenario} attack  test files : ",len(test_files),test_files)
    val_dataloader=DataLoader(ModbusFlowStream(
                shuffle=False,
                chunk_size=100,
                batch_size=64,
                csv_files=val_files,
                scalers=loaded_scalers['network-wide']['min_max_scalers'],
            ),batch_size=1,shuffle=False)
    test_dataloader=DataLoader(ModbusFlowStream(
                shuffle=False,
                chunk_size=100,
                batch_size=64,
                csv_files=test_files,
                scalers=loaded_scalers['network-wide']['min_max_scalers'],
            ),batch_size=1,shuffle=False)
    for trained_model in {Trained_AE_model,Trained_VAE_model,Trained_AAE_model}:
        print("*"*10,trained_model._get_name(),10*"*")
        train_eval(trained_model,None,val_dataloader,test_dataloader,shuffle_files=False,num_epochs=1,eval_epoch=1,criterion_method="mse",train_model=False,learning_rates=[0],weight_decays=[0])
        

scenario : compromised-scada node ied1b
----------- benign valid files: 7 ['dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-6-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-9-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-8-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-5-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-7-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-4-labeled.csv', 'dataset/ModbusDataset/benign/ied1b/ied1b-network-capture/ready/vethd9e14c0-normal-10-labeled.csv']
----------compromised-scada attack  test files :  8 ['dataset/ModbusDataset/attack/compromised-scada/ied1b/ied1b-network-captures/ready/vethc76bd3f-1-labeled.csv', 'dataset/ModbusDataset/attack/compromised-scada/ied1b/ied1b-network

### FedAvg - non iid distribution (ip based)

In [None]:
# ==============================================================================
# 1. SETUP: INSTALL LIBRARIES AND IMPORT DEPENDENCIES
# ==============================================================================
# In a Kaggle notebook, run this cell first to install the necessary libraries.
!pip install -q flwr[simulation] torch torchvision pandas scikit-learn matplotlib seaborn


In [16]:

import flwr as fl
from collections import OrderedDict
from typing import Dict, List, Tuple, Optional
import seaborn as sns
import os 
from flwr.common import Context # Make sure this import is added

# Suppress warning messages for a cleaner output
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Set a seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

#global device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")


Training on cuda:0


In [17]:

# ==============================================================================
# 2. CONFIGURATION: TWEAK  FEDERATED LEARNING EXPERIMENT
# ==============================================================================
class Config:
    """Global configuration class for the federated learning experiment."""
    # --- FL Parameters ---
    NUM_TRAIN_CLIENTS = 4
    NUM_ROUNDS = 10
    LOCAL_EPOCHS = 5
    BATCH_SIZE = 64
    LEARNING_RATE = 5e-5
    WEIGHT_DECAY = 1e-4
    # --- Strategy Selection ---
    # Choose from "FED_AVG", "FED_PROX"
    STRATEGY = "FED_AVG" 
    PROXIMAL_MU = 0.1 # Proximal term for FedProx
    # --- Model Selection ---
    # Choose from "AE" (Autoencoder) or "VAE" (Variational Autoencoder) or "AdverserialAutoencoder"
    MODEL_NAME = "AE"
    INPUT_DIM = 76
    # --- Anomaly Detection ---
    # ANOMALY_THRESHOLD is calculated dynamically in the evaluation function.
    EVAL_DATA_FILES = ["data_9.csv", "data_10.csv"]
    SHUFFLE_FILES=  True
# Instantiate the configuration
cfg = Config()



In [18]:

# ==============================================================================
# 3. DATA Distribution
# ==============================================================================


ied1b_train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("ied1b")!=-1]
ied1a_train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("ied1a")!=-1]
ied4c_train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("ied4c")!=-1]
scada_train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("scada-hmi")!=-1]
cent_agent_train_files=[col for col in modbus.dataset["benign_dataset_dir"] if col.find("central-agent")!=-1]

TRAIN_CLIENT_DATA_MAPPING = {
    0: ied1b_train_files,
    1: ied1a_train_files,
    2: ied4c_train_files,
    3: scada_train_files,
    4: cent_agent_train_files,
}

SERVER_EVALUATION_DATA_MAPPING = {
    0: val_files,
    1: test_files 
}

def load_data(id: int, node = "client" ):
    """Loads the data for a specific training client."""
    if node == "client":
        file_list = TRAIN_CLIENT_DATA_MAPPING[id]
        shuffle=cfg.SHUFFLE_FILES
    else: # server
        file_list = SERVER_EVALUATION_DATA_MAPPING[id]
        shuffle = False

    train_loader=DataLoader(ModbusFlowStream(
            shuffle=shuffle,
            chunk_size=1,
            batch_size=cfg.BATCH_SIZE ,
            csv_files=file_list,
            scalers=loaded_scalers['network-wide']['min_max_scalers'],
        ),batch_size=1,shuffle=False)
    return train_loader
def get_model():
    """Returns the model specified in the config."""
    if cfg.MODEL_NAME == "VAE":
        print(f"Using Variational Autoencoder (VAE) ")
        return VAE(input_dim=cfg.INPUT_DIM)
    elif cfg.MODEL_NAME == "AE":
        print(f"Using Autoencoder (AE) ")
        return AE(input_dim=cfg.INPUT_DIM)
    elif cfg.MODEL_NAME =="AdverserialAutoencoder":
        print(f"Using Adverserial Autoencoder (AAE) ")
        return AdverserialAutoencoder(input_dim=76)
    else:
        raise ValueError(f"Unknown model name: {cfg.MODEL_NAME}. Choose 'AE' or 'VAE' or 'AdverserialAutoencoder'.")


In [19]:

# ==============================================================================
# 5. FEDERATED LEARNING CLIENT: FlowerClient
# ==============================================================================
class FlowerClient(fl.client.NumPyClient):
    """Flower client for training."""
    def __init__(self, cid, model, trainloader):
        self.cid = cid
        self.model = model
        self.train_dataloader = trainloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        model =self.model
        lr = cfg.LEARNING_RATE
        wd= cfg.WEIGHT_DECAY
        
        criterion = nn.MSELoss(reduction='sum').to(DEVICE)
        if model._get_name()=="AdversarialAutoencoder":
            adversarial_criterion= nn.BCELoss(reduction="sum")
            optimizer_D = optim.Adam(model.discriminator.parameters(), lr=lr, weight_decay=wd)
            optimizer_G =  optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=lr, weight_decay=wd)
        else:
            AE_optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=wd)

        if cfg.STRATEGY == "FED_PROX":
            global_params = [torch.tensor(p, device=DEVICE) for p in parameters]

        for epoch in range(cfg.LOCAL_EPOCHS):
                time_1 = time.time()
                model.train()
                train_loss = 0
                ## for AAE
                Discriminator_loss = 0
                if cfg.SHUFFLE_FILES:
                    sys_rand = SystemRandom()
                    sys_rand.shuffle(self.train_dataloader.dataset.csv_files)
                for sequences, _ in self.train_dataloader:
                    sequences=sequences.squeeze().to(DEVICE)
                    if model._get_name()=="AdversarialAutoencoder":
                        target_ones= torch.ones(sequences.size(0), 1,device=DEVICE,dtype=torch.float)
                        target_zeros= torch.zeros(sequences.size(0), 1,device=DEVICE,dtype=torch.float)
                        random_latent = torch.randn(sequences.size(0), 2, device=DEVICE)
                        optimizer_G.zero_grad()
                        fake_z,decoded_seq = model(sequences)
                        G_loss = 0.001*adversarial_criterion(model.discriminator(fake_z),target_ones ) + 0.999*criterion(decoded_seq, sequences)
                        G_loss.backward()
                        optimizer_G.step()
                        # 2) discriminator loss
                        optimizer_D.zero_grad()
                        real_loss = adversarial_criterion(model.discriminator(random_latent), target_ones)
                        fake_loss = adversarial_criterion(model.discriminator(fake_z.detach()),  target_zeros)
                        D_loss =  0.001*0.5*(real_loss + fake_loss)
                        D_loss.backward()
                        optimizer_D.step()
                        train_loss+=G_loss.item()
                        Discriminator_loss+=D_loss.item()   
                    else:
                        AE_optimizer.zero_grad()
                        if model._get_name()=="AE":
                            recon = model(sequences)
                            loss = criterion(recon, sequences) / sequences.size(0)
                        elif model._get_name()=="VAE" or model._get_name()=="GRUVAE":
                            recon, mu, logvar = model(sequences)
                            loss = vae_loss_function(recon, sequences, mu, logvar) /sequences.size(0)
                            if cfg.STRATEGY == "FED_PROX":
                                proximal_term = 0.0
                                for local_w, global_w in zip(model.parameters(), global_params):
                                    proximal_term += (local_w - global_w).norm(2)
                                loss += (cfg.PROXIMAL_MU / 2) * proximal_term
                        loss.backward()
                        AE_optimizer.step()
                        train_loss += loss.item()
                print(f"Train : time {(time.time()-time_1):.2f} s",
                f"Epoch {epoch+1}")
                if model._get_name()=="AdversarialAutoencoder":
                    print(f"Generator Loss: {train_loss / len(self.train_dataloader):.4f}",
                        f"Discriminator Loss: {Discriminator_loss / len(self.train_dataloader):.4f}")
                else:
                    print(f"Train Loss: {train_loss / len(self.train_dataloader):.4f}")
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        return 0.0, 0, {}


In [None]:
from flower.common import ndarrays_to_parameters, parameters_to_ndarrays



# ==============================================================================
# 6. SERVER-SIDE LOGIC AND SIMULATION START
# ==============================================================================
def client_fn(context: Context) -> FlowerClient:
    """Create a Flower client instance for a given client ID."""
    # The client's ID is retrieved from the context.
    client_id = int(context.node_config["partition-id"])
    model = get_model().to(DEVICE)
    trainloader = load_data(client_id)
    return FlowerClient(client_id, model, trainloader).to_client()

def get_evaluate_fn():
    """Return an evaluation function for server-side evaluation."""
    val_dataloader = load_data(0)
    test_dataloader = load_data(1)
    eval_criterion = nn.MSELoss(reduction='none').to(DEVICE)

    def evaluate(
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        model = get_model() # Use the get_model function
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict, strict=True)
        model.to(DEVICE)
        model.eval()
        # Evaluate part
        all_val_losses = []
        all_val_labels = []
        print(f"--- Running Evaluation for Server round {server_round} ---")
        with torch.no_grad():
            for sequences, labels in val_dataloader:
                sequences = sequences.squeeze().to(DEVICE) 
                if labels.sum()!=0:
                    continue
                if model._get_name()=="AE":
                    recon = model(sequences)
                elif model._get_name()=="VAE" or model._get_name()=="GRUVAE" :
                    recon, _, _ = model(sequences)
                elif model._get_name()=="AdversarialAutoencoder":
                    _,recon= model(sequences)
                val_loss = eval_criterion(recon, sequences)
                if val_loss.dim() > 1:
                    val_loss = val_loss
                else:
                    val_loss = val_loss.unsqueeze(dim=0)
                    labels = labels.unsqueeze(dim=0)
                if val_loss.dim()==3:
                    ##GRU : mean of window
                    val_loss = val_loss.mean(dim=1)
                val_loss = val_loss.sum(dim=1)
                all_val_losses.extend(val_loss.cpu().numpy())
                all_val_labels.extend(labels.flatten().cpu().numpy())     
        threshold_1,std_mse = compute_threshold(all_val_losses,k=0)

        all_val_losses = np.array(all_val_losses).squeeze()  
        all_val_labels = np.array(all_val_labels).squeeze()  
        # If intrusion score > threshold, predict 1 (intrusion), else 0 (benign)
        # For FDR, get True Positives (TP) and False Positives (FP)
        
        predictions = (all_val_losses > threshold_1).astype(int)

        accuracy = accuracy_score(all_val_labels, predictions)
        print(f"Val: Accuracy: {accuracy:.4f}  ")
        model.eval() 

        all_test_losses = []
        all_test_labels = []
        with torch.no_grad():
            for sequences, labels in test_dataloader:
                sequences = sequences.squeeze().to(DEVICE)
                labels = labels.squeeze().to(DEVICE)
                if model._get_name()=="AE":
                    recon = model(sequences)
                elif model._get_name()=="VAE"  or model._get_name()=="GRUVAE":
                    recon, mu, logvar = model(sequences)
                elif model._get_name()=="AdversarialAutoencoder":
                    _,recon= model(sequences)

                intrusion_scores = eval_criterion(recon, sequences)
                if intrusion_scores.dim() > 1:
                    intrusion_scores = intrusion_scores
                else:
                    intrusion_scores = intrusion_scores.unsqueeze(dim=0)
                    labels = labels.unsqueeze(dim=0)
                if intrusion_scores.dim()==3:
                    ##GRU : mean of window
                    intrusion_scores = intrusion_scores.mean(dim=1)
                intrusion_scores = intrusion_scores.sum(dim=1)
                all_test_losses.extend(intrusion_scores.cpu().numpy())
                all_test_labels.extend(labels.cpu().numpy())

        all_test_losses = np.array(all_test_losses)
        all_test_labels = np.array(all_test_labels)
        test_result = {}
        for k in {1,3}:
            threshold=threshold_1+k*std_mse
            print(f" K: {k} K-SIGMA Threshold : ---thr {threshold:.4}")
            predictions = (all_test_losses > threshold).astype(int)
            binary_test_labels = (all_test_labels != 0).astype(int)

            # Find the indices where the prediction was incorrect
            misclassified_indices = np.where(binary_test_labels != predictions)[0]

            # Get the original labels for those misclassified instances
            misclassified_original_labels = all_test_labels[misclassified_indices]

            # To get a summary count of which labels were misclassified
            print(Counter(predictions),Counter(binary_test_labels))
            print(f"Counts of  labels: {dict(sorted(Counter(all_test_labels).items()))}")
            print(f"Counts of misclassified original labels: {dict(sorted(Counter(misclassified_original_labels).items()))}")
            accuracy = accuracy_score(binary_test_labels, predictions)
            f1 = f1_score(binary_test_labels, predictions, zero_division=0)
            recall = recall_score(binary_test_labels, predictions,zero_division=0)
            _, fp, _, tp = confusion_matrix(binary_test_labels, predictions, labels=[0, 1]).ravel()
            # FDR = FP / (FP + TP) 
            if (fp + tp) == 0:
                fdr = 0.0 
            else:
                fdr = fp / (fp + tp)
            test_result[k] = f"k= {k} ,Test : Accuracy: {accuracy:.4f} Recall : {recall:.4f} FDR: {fdr:.4f}  F1-score: {f1:.4f} "
            print(test_result[k])
        return np.sum(all_test_losses)/len(all_test_losses),test_result
    return evaluate

def get_initial_parameters(model: torch.nn.Module):
    """
    Initializes the model weights using Xavier uniform distribution
    and returns them as a Flower Parameters object.
    """
    temp_model = model()
    
    for param in temp_model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
            
    ndarrays = [val.cpu().numpy() for _, val in temp_model.state_dict().items()]
    return ndarrays_to_parameters(ndarrays)

In [None]:
import ray
# --- Select the Federation Strategy ---
evaluate_function = get_evaluate_fn()

if cfg.STRATEGY == "FED_PROX":
    strategy = fl.server.strategy.FedProx(
        fraction_fit=1.0, fraction_evaluate=0.0,
        min_fit_clients=cfg.NUM_TRAIN_CLIENTS,
        min_available_clients=cfg.NUM_TRAIN_CLIENTS,
        evaluate_fn=evaluate_function,
        proximal_mu=cfg.PROXIMAL_MU,
        initial_parameters=get_initial_parameters(cf)
    )
    print("Using FedProx strategy.")
else:
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0, fraction_evaluate=0.0,
        min_fit_clients=cfg.NUM_TRAIN_CLIENTS,
        min_available_clients=cfg.NUM_TRAIN_CLIENTS,
        evaluate_fn=evaluate_function,
    )
    print(f"Using FedAvg strategy with {cfg.MODEL_NAME} model.")

# --- Start the Simulation ---
print("Starting federated learning simulation...")
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=cfg.NUM_TRAIN_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=cfg.NUM_ROUNDS),
    strategy=strategy,
    client_resources={"num_cpus": 4, "num_gpus": 1/cfg.NUM_TRAIN_CLIENTS} if DEVICE.type == "cuda" else {"num_cpus": 4},
)
print("Federated learning simulation finished.")

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout


Using FedAvg strategy with AE model.
Starting federated learning simulation...


2025-07-17 13:33:32,745	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'accelerator_type:G': 1.0, 'node:__internal_head__': 1.0, 'node:172.27.10.149': 1.0, 'CPU': 4.0, 'memory': 6197892711.0, 'object_store_memory': 3098946355.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 4, 'num_gpus': 0.25}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=119723)[0m Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
[36m(ClientAppActor pid=119723)[0m (to allow more performant data types, such as the Arrow string type, and better interoperability w

[36m(ClientAppActor pid=119723)[0m Using Autoencoder (AE) 


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters


Using Autoencoder (AE) 
--- Running Evaluation for Server round 0 ---
-----------mse_loss mean :  17.2565 std: 0.0250
Val: Accuracy: 0.5170  
 K: 1 K-SIGMA Threshold : ---thr 17.28
Counter({0: 710557, 1: 515282}) Counter({0: 1225839})
Counts of  labels: {0: 1225839}
Counts of misclassified original labels: {0: 515282}
k= 1 ,Test : Accuracy: 0.5796 Recall : 0.0000 FDR: 1.0000  F1-score: 0.0000 
 K: 3 K-SIGMA Threshold : ---thr 17.33
Counter({0: 912970, 1: 312869}) Counter({0: 1225839})
Counts of  labels: {0: 1225839}
Counts of misclassified original labels: {0: 312869}


[92mINFO [0m:      initial parameters (loss, other metrics): 16.642735302107372, {1: 'k= 1 ,Test : Accuracy: 0.5796 Recall : 0.0000 FDR: 1.0000  F1-score: 0.0000 ', 3: 'k= 3 ,Test : Accuracy: 0.7448 Recall : 0.0000 FDR: 1.0000  F1-score: 0.0000 '}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)


k= 3 ,Test : Accuracy: 0.7448 Recall : 0.0000 FDR: 1.0000  F1-score: 0.0000 
[36m(ClientAppActor pid=119723)[0m Using Autoencoder (AE) 
