### **Section 1: Data Ingestion and Initial Consolidation**
This section handles the initial loading of the IoT-DIAD Dataset, which is spread across multiple CSV files. It mounts Google Drive, recursively finds all `.csv` files, and then consolidates them into a single, large HDF5 file for easier processing. During this process, column names are standardized to ensure consistency.
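The column-name standardization can be tried in isolation. The sketch below reuses the regular expression from the `sanitize_column_name` helper defined in the consolidation cell. The sample headers are illustrative CICFlowMeter-style names, not copied from a specific file. The duplicate-name check at the end is a hypothetical addition that the notebook does not perform: two raw headers such as `a/b` and `a b` sanitize to the same name.

```python
import re
from collections import Counter

def sanitize_column_name(col_name: str) -> str:
    """Strip whitespace, replace runs of non-alphanumeric characters with '_', lowercase."""
    col_name = col_name.strip()
    col_name = re.sub(r'[^a-zA-Z0-9_]+', '_', col_name)
    return col_name.lower()

# Illustrative raw headers (assumed; the real CSV headers may differ)
raw_headers = [" Flow ID", "Src IP", "Fwd Packets/s", "Flow Bytes/s", "Label"]
clean = [sanitize_column_name(h) for h in raw_headers]
print(clean)  # ['flow_id', 'src_ip', 'fwd_packets_s', 'flow_bytes_s', 'label']

# Hypothetical extra safeguard: fail loudly if two headers collapse to one name
dupes = [name for name, n in Counter(clean).items() if n > 1]
if dupes:
    raise ValueError(f"Sanitized column names collide: {dupes}")
```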

In [None]:
from google.colab import drive
import glob
import os

# 1. Mount Drive
drive.mount('/content/drive')

# --- EDIT THIS ---
# 2. Point this to the main folder on your Drive
DRIVE_DATA_FOLDER = '/content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset'
# -----------------

print(f"Looking for CSVs in: {DRIVE_DATA_FOLDER}")

# 3. Find all .csv files recursively
# This will search inside 'Benign', 'Brute Force', etc.
all_csv_files = glob.glob(f"{DRIVE_DATA_FOLDER}/**/*.csv", recursive=True)

if not all_csv_files:
    print("🚨 ERROR: No CSV files found. Check your DRIVE_DATA_FOLDER path.")
else:
    print(f"Found {len(all_csv_files)} CSV files to process.")

Mounted at /content/drive
Looking for CSVs in: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset
Found 23 CSV files to process.
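With `recursive=True`, the `**` pattern matches `.csv` files at any depth below `DRIVE_DATA_FOLDER`. The self-contained sketch below reproduces that search on a throwaway directory tree. It adds two things the notebook does not do, both assumptions:

- It sorts the result, because `glob` does not guarantee an order.
- As a hypothetical step, it derives a category name (for example `Spoofing` or `Mirai`) from the first folder level under the root. This is one way to attach a label when a CSV does not already carry one.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Build a small tree that mimics the dataset layout (names are illustrative)
    for rel in ["Benign/benign_1.csv", "Spoofing/DNS_Spoofing.csv",
                "Mirai/sub/flood.csv", "notes.txt"]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write("a,b\n1,2\n")

    # Same pattern as the notebook; sorted() makes the processing order deterministic
    csv_files = sorted(glob.glob(f"{root}/**/*.csv", recursive=True))

    # Hypothetical: use the first folder level under the root as the category label
    labels = [os.path.relpath(p, root).split(os.sep)[0] for p in csv_files]

    for p, lab in zip(csv_files, labels):
        print(lab, os.path.basename(p))
```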


### **Section 2: Data Cleaning and Preprocessing**
After consolidating the raw data, this section focuses on cleaning and preparing the dataset for model training. It involves several critical steps:
- **Dropping Unnecessary Columns**: Removing identifier columns (`flow_id`, `timestamp`, `src_ip`, `dst_ip`). These describe a specific flow, host, or capture time rather than traffic behavior, so they are unlikely to help a model generalize.
- **Handling Missing Values**: Removing every row that contains a null value. This is acceptable because only a small fraction of rows is affected (about 1,400 of roughly 1.62 million in this run).
- **Replacing Infinite Values**: Converting any infinite values (which often arise from division by zero, such as a per-second rate computed over a zero-length flow) to zero to prevent errors in numerical computations.

This stage produces a `master_cleaned_dataset.h5` file.
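The three cleaning steps can be checked on a tiny, made-up DataFrame before running them on the full dataset. The column names mirror the dataset's sanitized names, but the values are invented for illustration. The order follows the notebook: rows with nulls are dropped first, and only then are the remaining ±inf values set to 0. Note that `dropna()` does not treat `inf` as missing by default.

```python
import numpy as np
import pandas as pd

UNNECESSARY_COLS = ['flow_id', 'timestamp', 'src_ip', 'dst_ip']

df = pd.DataFrame({
    'flow_id': ['f1', 'f2', 'f3', 'f4'],
    'src_ip': ['10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4'],
    'flow_bytes_s': [1.0, np.inf, np.nan, 4.0],  # inf from a zero divisor, NaN = missing
    'flow_duration': [10, 0, 5, 7],
    'label': ['Benign', 'DOS', 'DOS', 'Benign'],
})

# 1. Drop identifier columns that are present ('timestamp'/'dst_ip' are absent here)
df = df.drop(columns=[c for c in UNNECESSARY_COLS if c in df.columns])

# 2. Drop any row that contains a null value (row 2 here)
df = df.dropna()

# 3. Replace the remaining +/-inf values with 0 (row 1 here)
df = df.replace([np.inf, -np.inf], 0)

print(df)
```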

In [None]:
import pandas as pd
import re
import os

# --- EDIT THIS ---
MASTER_HDF_FILE = '/content/drive/MyDrive/FL_Project/master_labeled_dataset.h5'
# -----------------
HDF_KEY = 'data'

def sanitize_column_name(col_name):
    """Cleans column names to be HDF5-safe."""
    col_name = col_name.strip()
    col_name = re.sub(r'[^a-zA-Z0-9_]+', '_', col_name)
    col_name = col_name.lower()
    return col_name

# Clean up any old file first
if os.path.exists(MASTER_HDF_FILE):
    os.remove(MASTER_HDF_FILE)
    print(f"Removed old version of {MASTER_HDF_FILE}")

print("Starting FAST 'all-at-once' aggregation and sanitization...")

# We will store all DataFrames in this list
all_dataframes = []
master_columns = None
total_rows = 0

for i, filepath in enumerate(all_csv_files):
    try:
        print(f"  Processing file {i+1}/{len(all_csv_files)}: {filepath}")

        # Read the ENTIRE file at once. No chunking.
        # low_memory=False makes pandas infer each column's dtype from the whole
        # file (avoiding mixed-type DtypeWarnings), at the cost of more memory.
        df_temp = pd.read_csv(filepath, low_memory=False)

        if df_temp.empty:
            print(f"  Skipping empty file: {filepath}")
            continue

        # --- Sanitize Column Names ---
        clean_cols = [sanitize_column_name(col) for col in df_temp.columns]
        df_temp.columns = clean_cols

        # On the very first file, store the clean column names
        if master_columns is None:
            master_columns = df_temp.columns

        # --- Schema Mismatch Check ---
        # Ensure all other files match the first one
        if list(df_temp.columns) != list(master_columns):
            print(f"  WARNING: Column mismatch in {filepath}. Re-aligning columns.")
            # This will add missing columns (with NaN) and drop extra ones
            df_temp = df_temp.reindex(columns=master_columns)

        all_dataframes.append(df_temp)
        total_rows += len(df_temp)

    except Exception as e:
        print(f"  🚨 ERROR processing {filepath}: {e}. Skipping file.")

print("\n----------------------------------")
print("All files loaded into memory. Now concatenating...")

# --- This is the single "big memory" step ---
# It combines all the dataframes in the list into one.
df_master = pd.concat(all_dataframes, ignore_index=True)

print(f"Concatenation complete. Total rows: {total_rows}")
print(f"Final shape: {df_master.shape}")

print(f"Saving to HDF5 file: {MASTER_HDF_FILE}...")
# Save the single, final dataframe
df_master.to_hdf(
    MASTER_HDF_FILE,
    HDF_KEY,
    format='table',
    mode='w' # 'w' for write (not append)
)

print("\n✅ All files merged and SANITIZED into one HDF5 file!")
print(f"Master file saved at: {MASTER_HDF_FILE}")
print("----------------------------------")

Removed old version of /content/drive/MyDrive/FL_Project/master_labeled_dataset.h5
Starting FAST 'all-at-once' aggregation and sanitization...
  Processing file 1/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Spoofing/DNS_Spoofing.pcap_Flow.csv
  Processing file 2/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Spoofing/MITM-ArpSpoofing.pcap_Flow.csv
  Processing file 3/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Recon/VulnerabilityScan.pcap_Flow.csv
  Processing file 4/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Web-Based/SqlInjection.pcap_Flow.csv
  Processing file 5/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Web-Based/Uploading_Attack.pcap_Flow.csv
  Processing file 6/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Web-Based/XSS.pcap_Flow.csv
  Processing file 7/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Dataset/Mirai/Mirai-greeth_flood1.pcap_Flow.csv
  Processing file 8/23: /content/drive/MyDrive/FL_Project/IoT-DIAD-Datas


✅ All files merged and SANITIZED into one HDF5 file!
Master file saved at: /content/drive/MyDrive/FL_Project/master_labeled_dataset.h5
----------------------------------
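The schema check relies on `DataFrame.reindex(columns=...)`, which does three things:

- returns the columns in the reference order,
- fills columns that a file lacks with NaN, and
- silently drops columns that the reference does not have.

The toy example below uses invented column names to show all three effects.

```python
import pandas as pd

master_columns = pd.Index(['src_port', 'flow_duration', 'label'])

# A file whose columns are reordered, missing 'flow_duration', and carrying an extra column
df_temp = pd.DataFrame({
    'label': ['Benign', 'DOS'],
    'src_port': [443, 80],
    'extra_col': [1, 2],
})

aligned = df_temp.reindex(columns=master_columns)
print(aligned)
print(aligned.isna().sum().to_dict())
```

When a file lacks a column, that column is filled with NaN for every one of its rows. The later cleaning step's `dropna()` would therefore discard all of that file's rows. Logging `set(df_temp.columns) ^ set(master_columns)` before reindexing may be worth adding to catch this early.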


In [None]:
df_master.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1619747 entries, 0 to 1619746
Data columns (total 84 columns):
 #   Column                      Non-Null Count    Dtype  
---  ------                      --------------    -----  
 0   flow_id                     1619747 non-null  object 
 1   src_ip                      1619747 non-null  object 
 2   src_port                    1619747 non-null  int64  
 3   dst_ip                      1619747 non-null  object 
 4   dst_port                    1619747 non-null  int64  
 5   protocol                    1619747 non-null  int64  
 6   timestamp                   1619747 non-null  object 
 7   flow_duration               1619747 non-null  int64  
 8   total_fwd_packet            1619747 non-null  int64  
 9   total_bwd_packets           1619747 non-null  int64  
 10  total_length_of_fwd_packet  1619747 non-null  float64
 11  total_length_of_bwd_packet  1619747 non-null  float64
 12  fwd_packet_length_max       1619747 non-null  float64

### **Section 3: Feature Engineering and Scaling**
This section prepares the cleaned data for machine learning by transforming raw features into a suitable format. It involves:
- **Identifying Feature Types**: Distinguishing between numeric and categorical features.
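- **Scaling Numeric Features**: Applying `MinMaxScaler` to map each numeric feature to the [0, 1] range. Features measured on very different scales (byte counts, durations, packet counts) then contribute comparably, which matters especially for gradient-trained models such as neural networks.
- **Hashing Categorical Features**: Using `FeatureHasher` to convert high-cardinality categorical features (such as ports) into a fixed-size numerical vector. This bounds the dimensionality regardless of how many distinct values appear, at the cost of occasional hash collisions between values.
- **Saving Preprocessors**: Storing the fitted `MinMaxScaler`, the `FeatureHasher` configuration (the hasher is stateless and needs no fitting), and the numeric/categorical column lists, so that new, unseen data is processed identically.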

The output is a `final_processed_data.npz` file containing the scaled feature matrix and original string labels.
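Below is a small sketch of the two transforms on invented data. `MinMaxScaler` maps each numeric column to [0, 1] using the minimum and maximum seen during `fit`. `FeatureHasher(input_type='string')` expects each sample to be an iterable of strings, so the sketch builds one `"column=value"` token per categorical column and row.

The sketch also shows a pitfall: passing `to_dict('records')` with `input_type='string'` does not work as intended. Iterating a dict yields only its keys, so every row hashes just the column names and all rows get the same vector. The `n_features=1024` value is arbitrary for the demo.

```python
import numpy as np
import pandas as pd
from sklearn.feature_extraction import FeatureHasher
from sklearn.preprocessing import MinMaxScaler

df = pd.DataFrame({
    'flow_duration': [0, 50, 100],  # numeric: scaled to [0, 1]
    'protocol': [6, 17, 6],         # categorical, stored as int
    'dst_port': [80, 53, 80],       # categorical, stored as int
})
numeric_cols = ['flow_duration']
categorical_cols = ['protocol', 'dst_port']

scaler = MinMaxScaler().fit(df[numeric_cols])
scaled = scaler.transform(df[numeric_cols])  # values 0.0, 0.5, 1.0

hasher = FeatureHasher(n_features=1024, input_type='string')

# Pitfall: each dict is iterated as its keys, so only column names are hashed
buggy = hasher.transform(df[categorical_cols].astype(str).to_dict('records')).toarray()

# Working input: a list of "column=value" tokens per row
tokens = [
    [f"{col}={val}" for col, val in zip(categorical_cols, row)]
    for row in df[categorical_cols].astype(str).itertuples(index=False, name=None)
]
hashed = hasher.transform(tokens).toarray()

X = np.hstack((scaled, hashed))
print(X.shape)  # (3, 1025)
print("buggy rows identical:", bool((buggy == buggy[0]).all()))
print("rows 0 and 2 match:", bool((hashed[0] == hashed[2]).all()))
```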

In [None]:
import pandas as pd
import numpy as np
import os

# --- Configuration ---
# 1. The master file you just created
MASTER_HDF_FILE = '/content/drive/MyDrive/FL_Project/master_labeled_dataset.h5'
HDF_KEY = 'data'

# 2. The *new* clean file we will create
CLEANED_HDF_FILE = '/content/drive/MyDrive/FL_Project/master_cleaned_dataset.h5'

# 3. Columns to drop
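# These are identifiers (flow ID, capture time, IP addresses) that tie a
# row to a specific flow or host rather than to its traffic behavior, so a
# simple DNN could memorize them instead of generalizing. The 'label'
# column is kept for the non-IID split later.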
UNNECESSARY_COLS = [
    'flow_id',
    'timestamp',
    'src_ip',
    'dst_ip',
]
# ---------------------

print(f"Loading master file: {MASTER_HDF_FILE}")
# 1. Load the entire 1GB file into memory
df_master = pd.read_hdf(MASTER_HDF_FILE, HDF_KEY)
print("Load complete.")

print("\n--- BEFORE CLEANING ---")
df_master.info()  # info() prints directly and returns None

# 2. Drop unnecessary columns
# We will keep 'label' for now to do our non-IID split later
cols_to_drop = [col for col in UNNECESSARY_COLS if col in df_master.columns]
df_master.drop(columns=cols_to_drop, inplace=True)
print(f"\nDropped {len(cols_to_drop)} unnecessary columns.")

# 3. Drop all rows with ANY null/missing values
# The dataset is large enough that losing these rows is acceptable.
initial_rows = len(df_master)
df_master.dropna(inplace=True)
final_rows = len(df_master)
print(f"Dropped {initial_rows - final_rows} rows with null values.")

# 4. Replace any lingering infinity values (from division by zero)
df_master.replace([np.inf, -np.inf], 0, inplace=True)
print("Replaced all -inf/+inf values with 0.")

print("\n--- AFTER CLEANING ---")
print(f"Final shape: {df_master.shape}")
df_master.info()

# 5. Save the new, clean DataFrame
print(f"\nSaving clean data to: {CLEANED_HDF_FILE}")
df_master.to_hdf(
    CLEANED_HDF_FILE,
    HDF_KEY,
    format='table',
    mode='w'
)

print("✅ Clean file saved to Google Drive.")

Loading master file: /content/drive/MyDrive/FL_Project/master_labeled_dataset.h5
Load complete.

--- BEFORE CLEANING ---
<class 'pandas.core.frame.DataFrame'>
Index: 1619747 entries, 0 to 1619746
Data columns (total 84 columns):
 #   Column                      Non-Null Count    Dtype  
---  ------                      --------------    -----  
 0   flow_id                     1619747 non-null  object 
 1   src_ip                      1619747 non-null  object 
 2   src_port                    1619747 non-null  int64  
 3   dst_ip                      1619747 non-null  object 
 4   dst_port                    1619747 non-null  int64  
 5   protocol                    1619747 non-null  int64  
 6   timestamp                   1619747 non-null  object 
 7   flow_duration               1619747 non-null  int64  
 8   total_fwd_packet            1619747 non-null  int64  
 9   total_bwd_packets           1619747 non-null  int64  
 10  total_length_of_fwd_packet  1619747 non-null  float64

✅ Clean file saved to Google Drive.


### **Section 4: Feature Reduction**
To optimize model performance, reduce training time, and mitigate the curse of dimensionality, this section employs two feature reduction techniques:
- **Correlation Filtering**: Eliminating highly correlated features to avoid multicollinearity and redundancy.
- **Tree-Based Feature Selection (LightGBM)**: Utilizing a LightGBM classifier to rank features by importance and select the top `K` most relevant features for the given task.

The list of selected feature names (a lightweight feature-selection 'pipeline') is saved together with the `final_reduced_data.npz` file, ready for the next stage.
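The correlation filter looks only at the strict upper triangle of the absolute correlation matrix. For each highly correlated pair, it therefore drops the later column and keeps the earlier one.

The toy data below is invented. It also shows that a constant (zero-variance) column has an undefined (NaN) correlation. Because NaN never exceeds the threshold, this step does not remove such a column.

```python
import numpy as np
import pandas as pd

CORR_THRESHOLD = 0.95

df = pd.DataFrame({
    'x': [1.0, 2.0, 3.0, 4.0, 5.0],
    'y': [3.0, 5.0, 7.0, 9.0, 11.0],  # y = 2x + 1, perfectly correlated with x
    'z': [5.0, 1.0, 4.0, 2.0, 3.0],   # weakly correlated with x (r = -0.3)
    'c': [1.0, 1.0, 1.0, 1.0, 1.0],   # constant: its correlations are NaN
})

corr_matrix = df.corr().abs()
# Keep only entries above the diagonal so each pair is considered once
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
cols_to_drop = [col for col in upper_tri.columns if any(upper_tri[col] > CORR_THRESHOLD)]

print(cols_to_drop)  # ['y']
```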

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_extraction import FeatureHasher
import pickle
import os

# --- Configuration ---
# 1. The clean file from Cell 3
CLEANED_HDF_FILE = '/content/drive/MyDrive/FL_Project/master_cleaned_dataset.h5'
HDF_KEY = 'data'

# 2. The path to save your *fitted* preprocessors
PREPROCESSOR_PATH = '/content/drive/MyDrive/FL_Project/preprocessors.pkl'

# 3. The path to save your *final, processed* data
FINAL_PROCESSED_FILE = '/content/drive/MyDrive/FL_Project/final_processed_data.npz'
# ---------------------

print(f"Loading clean data from: {CLEANED_HDF_FILE}")
# 1. Load the clean dataset
df_clean = pd.read_hdf(CLEANED_HDF_FILE, HDF_KEY)
print("Load complete.")

# --- 2. Identify Feature Columns Programmatically ---

# a. Define the label column
LABEL_COL = 'label'

# b. Define known categorical features (even if they are 'int' type)
#    These are features that represent categories, not quantities.
KNOWN_CATEGORICAL = [
    'src_port',
    'dst_port',
    'protocol',
    'fwd_psh_flags',
    'bwd_psh_flags',
    'fwd_urg_flags',
    'bwd_urg_flags',
    'fin_flag_count',
    'syn_flag_count',
    'rst_flag_count',
    'psh_flag_count',
    'ack_flag_count',
    'urg_flag_count',
    'cwr_flag_count',
    'ece_flag_count'
]

# Find which of these known categoricals actually exist in our clean DataFrame
CATEGORICAL_COLS = [col for col in KNOWN_CATEGORICAL if col in df_clean.columns]

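# c. Numeric columns are all remaining numeric columns (order preserved, so the
#    saved column list and the scaler's column order are reproducible across runs)
all_numeric_cols = df_clean.select_dtypes(include=[np.number]).columns
NUMERIC_COLS = [col for col in all_numeric_cols if col not in CATEGORICAL_COLS]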

# d. This is the string label we will use for splitting
Y_LABELS_STRING = df_clean[LABEL_COL]

print(f"\nIdentified {len(NUMERIC_COLS)} numeric features.")
print(f"Identified {len(CATEGORICAL_COLS)} categorical features.")
print(f"Identified {LABEL_COL} as the label.")

# --- 3. Fit and Save Preprocessors ---
print("\nFitting preprocessors on clean data...")

# a. Fit Scaler for numeric data
scaler = MinMaxScaler()
scaler.fit(df_clean[NUMERIC_COLS])
print("  MinMaxScaler fitted.")

# b. Initialize Hasher for categorical data
N_FEATURES_HASH = 10  # We'll hash all categorical features down to 10
hasher = FeatureHasher(n_features=N_FEATURES_HASH, input_type='string')
print(f"  FeatureHasher initialized with {N_FEATURES_HASH} features.")

# c. Save the fitted preprocessors to Google Drive
with open(PREPROCESSOR_PATH, 'wb') as f:
    pickle.dump({'scaler': scaler, 'hasher': hasher,
                 'numeric_cols': NUMERIC_COLS,
                 'categorical_cols': CATEGORICAL_COLS}, f)
print(f"Preprocessors saved to: {PREPROCESSOR_PATH}")

# --- 4. Transform the Entire Dataset ---
print("\nTransforming the full dataset in memory...")

# a. Scale all numeric data
scaled_numeric_data = scaler.transform(df_clean[NUMERIC_COLS])
print(f"  Numeric data scaled. Shape: {scaled_numeric_data.shape}")

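# b. Hash all categorical data
#    FeatureHasher(input_type='string') expects each sample to be an iterable
#    of strings. Passing dicts would hash only the dict *keys* (the column
#    names), giving every row the same vector, so build "column=value" tokens.
cat_tokens = [
    [f"{col}={val}" for col, val in zip(CATEGORICAL_COLS, row)]
    for row in df_clean[CATEGORICAL_COLS].astype(str).itertuples(index=False, name=None)
]
hashed_categorical_data = hasher.transform(cat_tokens)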
print(f"  Categorical data hashed. Shape: {hashed_categorical_data.shape}")

# c. Combine into your final feature matrix 'X_processed'
X_processed = np.hstack((scaled_numeric_data, hashed_categorical_data.toarray()))

print(f"\nFinal feature matrix 'X_processed' created.")
print(f"  Final X shape: {X_processed.shape}")
print(f"  Final y shape: {Y_LABELS_STRING.shape}")

# --- 5. Save Final Processed Data ---
# This saves our work before the final split, a very good checkpoint.
print(f"\nSaving final processed data to: {FINAL_PROCESSED_FILE}")
np.savez_compressed(
    FINAL_PROCESSED_FILE,
    X=X_processed,
    y_str=Y_LABELS_STRING.values  # Save the string labels for splitting
)

print("\n----------------------------------")
print("✅ Cell 4 Complete!")
print("Your data is now fully preprocessed and saved.")
print("You are ready for the final step: Cell 5 (Splitting).")
print("----------------------------------")

Loading clean data from: /content/drive/MyDrive/FL_Project/master_cleaned_dataset.h5
Load complete.

Identified 64 numeric features.
Identified 15 categorical features.
Identified label as the label.

Fitting preprocessors on clean data...
  MinMaxScaler fitted.
  FeatureHasher initialized with 10 features.
Preprocessors saved to: /content/drive/MyDrive/FL_Project/preprocessors.pkl

Transforming the full dataset in memory...
  Numeric data scaled. Shape: (1618373, 64)
  Categorical data hashed. Shape: (1618373, 10)

Final feature matrix 'X_processed' created.
  Final X shape: (1618373, 74)
  Final y shape: (1618373,)

Saving final processed data to: /content/drive/MyDrive/FL_Project/final_processed_data.npz

----------------------------------
✅ Cell 4 Complete!
Your data is now fully preprocessed and saved.
You are ready for the final step: Cell 5 (Splitting).
----------------------------------


In [None]:
# First, we need to install LightGBM, which is a very fast
# and powerful tree-based model for feature selection.
!pip install lightgbm

import pandas as pd
import numpy as np
import pickle
import os
import lightgbm as lgbm

# --- Configuration ---
# 1. The full, processed data file from Cell 4
FINAL_PROCESSED_FILE = '/content/drive/MyDrive/FL_Project/final_processed_data.npz'

# 2. The preprocessor file (to get our feature names)
PREPROCESSOR_PATH = '/content/drive/MyDrive/FL_Project/preprocessors.pkl'

# 3. The path to save our "pipeline" (the list of selected features)
SELECTED_FEATURES_FILE = '/content/drive/MyDrive/FL_Project/selected_features.pkl'

# 4. The final .npz file with *reduced* features
FINAL_REDUCED_DATA_FILE = '/content/drive/MyDrive/FL_Project/final_reduced_data.npz'

# 5. Reduction Parameters
CORR_THRESHOLD = 0.95  # Drop features with > 95% correlation
TOP_K_FEATURES = 30    # Select the 30 best features from the model
# ---------------------

print(f"Loading full processed data from: {FINAL_PROCESSED_FILE}")
# 1. Load data from Cell 4
# allow_pickle=True is required because y_str was saved as an object (string) array
data = np.load(FINAL_PROCESSED_FILE, allow_pickle=True)
X_processed = data['X']
y_labels_string = data['y_str']

print(f"Loading preprocessor info from: {PREPROCESSOR_PATH}")
# 2. Load preprocessor to get feature names
with open(PREPROCESSOR_PATH, 'rb') as f:
    preprocessors = pickle.load(f)

NUMERIC_COLS = preprocessors['numeric_cols']
CATEGORICAL_COLS = preprocessors['categorical_cols']
N_FEATURES_HASH = preprocessors['hasher'].n_features  # read from the saved hasher so it always matches Cell 4

# 3. Reconstruct the Full Feature Name List
# This is critical for interpreting our results.
hash_feature_names = [f'hash_{i}' for i in range(N_FEATURES_HASH)]
all_feature_names = NUMERIC_COLS + hash_feature_names

print(f"Original feature count: {len(all_feature_names)}")

# 4. Create DataFrame for analysis
# This is necessary for correlation and easy filtering.
df_processed = pd.DataFrame(X_processed, columns=all_feature_names)

# --- 5. Tech 1: Correlation Filtering ---
print(f"\nRunning Tech 1: Correlation Filtering (Threshold={CORR_THRESHOLD})...")
corr_matrix = df_processed.corr().abs()
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
cols_to_drop_corr = [col for col in upper_tri.columns if any(upper_tri[col] > CORR_THRESHOLD)]

print(f"  Found {len(cols_to_drop_corr)} highly correlated features to drop.")
print(f"  Dropped features: {cols_to_drop_corr}")

# Drop these columns to create our first filtered dataset
df_filtered_1 = df_processed.drop(columns=cols_to_drop_corr)
print(f"  Features remaining after correlation filtering: {len(df_filtered_1.columns)}")

# --- 6. Tech 2: Tree-Based Selection ---
print(f"\nRunning Tech 2: Tree-Based Selection (Top K={TOP_K_FEATURES})...")

# a. Convert string labels to numbers (e.g., 'Benign'=0, 'DOS'=1, etc.)
# This is needed to train the classifier.
y_numeric, _ = pd.factorize(y_labels_string)

# b. Initialize and train a LightGBM classifier
# LGBM is very fast and memory-efficient.
lgbm_selector = lgbm.LGBMClassifier(n_estimators=100, random_state=42, n_jobs=-1)

print("  Training feature selection model...")
lgbm_selector.fit(df_filtered_1, y_numeric)
print("  Model training complete.")

# c. Get feature importances
importances = lgbm_selector.feature_importances_
importance_df = pd.DataFrame({
    'feature': df_filtered_1.columns,
    'importance': importances
}).sort_values(by='importance', ascending=False)

# d. Select the Top K features
final_selected_features = list(importance_df.head(TOP_K_FEATURES)['feature'])

print(f"\n--- Top {TOP_K_FEATURES} Selected Features ---")
print(importance_df.head(TOP_K_FEATURES))
print("---------------------------------")

# --- 7. Save the "Pipeline" and Final Data ---

# a. Save the list of feature names. This *is* your "pipeline".
# In a real app, you would load this list and use it to filter
# new, incoming data *after* it has been scaled and hashed.
with open(SELECTED_FEATURES_FILE, 'wb') as f:
    pickle.dump(final_selected_features, f)
print(f"\nFeature selection 'pipeline' (list of names) saved to: {SELECTED_FEATURES_FILE}")

# b. Filter your full dataset to *only* these final features
X_final_reduced = df_filtered_1[final_selected_features].values

print(f"\nFinal feature matrix shape: {X_final_reduced.shape}")

# c. Save the new, reduced X matrix and the original string labels
np.savez_compressed(
    FINAL_REDUCED_DATA_FILE,
    X=X_final_reduced,
    y_str=y_labels_string  # We save the original string labels for Cell 5
)

print(f"Final REDUCED dataset saved to: {FINAL_REDUCED_DATA_FILE}")
print("\n----------------------------------")
print("✅ Cell 4.5 (Feature Reduction) Complete!")
print("You are now ready for the final step: Cell 5 (Splitting).")
print("----------------------------------")

Loading full processed data from: /content/drive/MyDrive/FL_Project/final_processed_data.npz
Loading preprocessor info from: /content/drive/MyDrive/FL_Project/preprocessors.pkl
Original feature count: 74

Running Tech 1: Correlation Filtering (Threshold=0.95)...
  Found 16 highly correlated features to drop.
  Dropped features: ['packet_length_mean', 'idle_min', 'idle_max', 'flow_iat_max', 'fwd_packet_length_mean', 'average_packet_size', 'active_mean', 'bwd_packet_length_mean', 'fwd_packets_s', 'fwd_iat_min', 'fwd_iat_total', 'packet_length_variance', 'idle_mean', 'flow_iat_min', 'fwd_packet_length_min', 'flow_iat_mean']
  Features remaining after correlation filtering: 58

Running Tech 2: Tree-Based Selection (Top K=30)...
  Training feature selection model...
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.413492 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
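The comment in the cell above says new flows must be scaled and hashed exactly as during training before the saved feature list is applied. The sketch below shows one way to do that, under these assumptions:

- The `transform_new_flows` helper is hypothetical.
- It fits a tiny scaler inline instead of loading `preprocessors.pkl`.
- It assumes the hashed columns are named `hash_0`, `hash_1`, and so on, in the order the notebook uses to reconstruct `all_feature_names`.
- It hashes `"column=value"` tokens. Whatever token format is used here must match the one used to build the training matrix.

```python
import numpy as np
import pandas as pd
from sklearn.feature_extraction import FeatureHasher
from sklearn.preprocessing import MinMaxScaler

def transform_new_flows(df_new, scaler, hasher, numeric_cols, categorical_cols, selected_features):
    """Hypothetical helper: scale + hash new rows, then keep only the selected features."""
    scaled = scaler.transform(df_new[numeric_cols])
    tokens = [
        [f"{c}={v}" for c, v in zip(categorical_cols, row)]
        for row in df_new[categorical_cols].astype(str).itertuples(index=False, name=None)
    ]
    hashed = hasher.transform(tokens).toarray()
    # Same column order as all_feature_names: numeric columns, then hash_0..hash_{n-1}
    names = list(numeric_cols) + [f'hash_{i}' for i in range(hasher.n_features)]
    full = pd.DataFrame(np.hstack((scaled, hashed)), columns=names)
    return full[selected_features].values

# Toy stand-ins for the objects stored in preprocessors.pkl / selected_features.pkl
numeric_cols = ['flow_duration', 'flow_bytes_s']
categorical_cols = ['protocol']
train = pd.DataFrame({'flow_duration': [0, 100], 'flow_bytes_s': [0.0, 10.0], 'protocol': [6, 17]})
scaler = MinMaxScaler().fit(train[numeric_cols])
hasher = FeatureHasher(n_features=4, input_type='string')
selected_features = ['flow_bytes_s', 'hash_2']

new_flows = pd.DataFrame({'flow_duration': [50], 'flow_bytes_s': [5.0], 'protocol': [6]})
X_new = transform_new_flows(new_flows, scaler, hasher, numeric_cols, categorical_cols, selected_features)
print(X_new.shape)  # (1, 2)
print(X_new[0, 0])  # 0.5: flow_bytes_s scaled with the training min/max
```

`MinMaxScaler` does not clip by default. New flows whose values fall outside the training minimum and maximum therefore map outside [0, 1].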

### **Section 5: Non-IID Data Splitting for Federated Learning**
This crucial section prepares the dataset for a Federated Learning (FL) setup by creating non-IID (not independent and identically distributed) client datasets. This simulates real-world scenarios in which each client sees a different mix of traffic.

The steps include:
- **Global Test Set Creation**: First, a portion of the total data is set aside as a global test set, stratified by labels to ensure all attack types are represented. This set is for server-side evaluation of the final federated model.
- **Client Data Partitioning**: The remaining data is then partitioned into client-specific datasets according to predefined non-IID criteria, with specific attack types assigned to the 'hospital' and 'factory' clients. Because `Benign` appears in both clients' label lists, both clients receive the same benign rows.
- **Local Train/Test Splits**: Each client's data is further split into local training and testing sets, also stratified by labels.
- **Binary Label Conversion**: Attack labels are converted into a binary format (0 for 'Benign', 1 for 'Attack') suitable for binary classification models.
- **Saving Client Data**: The processed client training and testing sets, along with the global test set, are saved as compressed `.npz` files, ready to be distributed to FL clients.
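The split logic can be exercised on a toy label vector before running it on 1.6 million rows. The sketch below uses invented counts (20 Benign, 10 each of Spoofing, DOS, and Mirai). To keep the numbers small, it uses a 20% global test fraction instead of the notebook's 10%, and shorter client label lists. It shows two things:

- A stratified split keeps the class proportions.
- Because `Benign` is listed for both clients, both receive exactly the same benign rows.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

NORMAL_LABEL = 'Benign'
CLIENT_1_LABELS = ['Benign', 'Spoofing']      # "hospital" (toy subset)
CLIENT_2_LABELS = ['Benign', 'DOS', 'Mirai']  # "factory" (toy subset)

labels = ['Benign'] * 20 + ['Spoofing'] * 10 + ['DOS'] * 10 + ['Mirai'] * 10
df = pd.DataFrame({'f0': np.arange(len(labels), dtype=float), 'label': labels})

# 1. Stratified global hold-out first
df_main, df_global_test = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df['label'])
print(df_global_test['label'].value_counts().to_dict())

# 2. Non-IID assignment by label on the remaining rows
df_c1 = df_main[df_main['label'].isin(CLIENT_1_LABELS)]
df_c2 = df_main[df_main['label'].isin(CLIENT_2_LABELS)]
shared_benign = df_c1.index.intersection(df_c2.index)
print(len(df_c1), len(df_c2), len(shared_benign))  # 24 32 16

# 3. Binary labels and a local stratified 80/20 split for one client
y_c1 = np.where(df_c1['label'] == NORMAL_LABEL, 0, 1)
X_tr, X_te, y_tr, y_te = train_test_split(
    df_c1.drop(columns=['label']).values, y_c1,
    test_size=0.2, random_state=42, stratify=y_c1)
print(len(y_tr), len(y_te))  # 19 5
```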

In [3]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split

# --- Configuration ---
# 1. The final reduced data file from your last step (now an NPZ)
FINAL_REDUCED_DATA_FILE = '/content/drive/MyDrive/FL_Project/final_reduced_data.npz'

# 2. The final output path for your client .npz files
OUTPUT_NPZ_PATH = '/content/drive/MyDrive/FL_Project/client_data/'

# 3. How much data to hold out for the final "Global" test set
GLOBAL_TEST_SET_SIZE = 0.1  # 10%

# 4. --- This is your Non-IID Split Logic ---
NORMAL_LABEL = 'Benign'
CLIENT_1_LABELS = ['Benign', 'Spoofing', 'Web-Based'] # "Hospital"
CLIENT_2_LABELS = ['Benign', 'Brute Force', 'DOS', 'DDOS', 'Mirai', 'Recon'] # "Factory"
# -----------------------------------------------

print(f"Loading final reduced dataset from: {FINAL_REDUCED_DATA_FILE}")
# 1. Load your final, reduced data (X and y_str) from the NPZ file
data = np.load(FINAL_REDUCED_DATA_FILE, allow_pickle=True)
X_final_reduced = data['X']
y_labels_string = data['y_str']

# Reconstruct a DataFrame for easier handling with existing code that expects it
df_final = pd.DataFrame(X_final_reduced)
df_final['label'] = y_labels_string # Add the label column back

print("Load complete.")

# --- 2. NEW: Create the Global Test Set FIRST ---
print(f"\nCreating Global Test Set (Size: {GLOBAL_TEST_SET_SIZE * 100}%) ")

# We split the *entire* dataset into a main (90%) and global test (10%)
# We stratify by 'label' to ensure all attack types are in the test set.
df_main_train, df_global_test = train_test_split(
    df_final,
    test_size=GLOBAL_TEST_SET_SIZE,
    random_state=42,
    stratify=df_final['label']
)

print(f"Main data for clients: {len(df_main_train)} samples")
print(f"Global test set: {len(df_global_test)} samples")

# --- 3. Process and Save the Global Test Set ---
# This file is for the SERVER person to evaluate the final model.
print("\nProcessing Global Test Set...")
y_str_global_test = df_global_test['label']
X_global_test = df_global_test.drop(columns=['label']).values
y_binary_global_test = np.where(y_str_global_test == NORMAL_LABEL, 0, 1)

global_test_file = os.path.join(OUTPUT_NPZ_PATH, "global_test_set.npz")
os.makedirs(OUTPUT_NPZ_PATH, exist_ok=True)
np.savez_compressed(global_test_file, X=X_global_test, y=y_binary_global_test)
print(f"  Saved Global Test Set to: {global_test_file}")

# --- 4. Apply Non-IID split on the *remaining* data ---
print("\nApplying Non-IID split to main training data...")

# We use *only* df_main_train to create the client data
df_client_1 = df_main_train[df_main_train['label'].isin(CLIENT_1_LABELS)]
df_client_2 = df_main_train[df_main_train['label'].isin(CLIENT_2_LABELS)]

print(f"Client 1 ('hospital') has {len(df_client_1)} total samples.")
print(f"  Labels: {df_client_1['label'].unique()}")
print(f"Client 2 ('factory') has {len(df_client_2)} total samples.")
print(f"  Labels: {df_client_2['label'].unique()}")

# 5. Define the helper function to process and save each client
def split_and_save_client(df_client, client_id, drive_path):
    print(f"\nProcessing {client_id}...")

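    # a. Separate features from the string labels
    y_str = df_client['label']
    X = df_client.drop(columns=['label']).values

    # b. Map labels to binary: 0 = Benign, 1 = any attack
    y_binary = np.where(y_str == NORMAL_LABEL, 0, 1)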
    print(f"  Found {np.sum(y_binary == 0)} normal samples and {np.sum(y_binary == 1)} attack samples.")

    # c. Create the 80/20 train/test split *for this client*
    #    (This is their *local* train/test set)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_binary,
        test_size=0.2,
        random_state=42,
        stratify=y_binary
    )
    print(f"  Split into {len(y_train)} local train and {len(y_test)} local test samples.")

    # d. Save to the .npz files your client code expects
    train_file = os.path.join(drive_path, f"client_{client_id}_train.npz")
    test_file = os.path.join(drive_path, f"client_{client_id}_test.npz")

    np.savez_compressed(train_file, X=X_train, y=y_train)
    np.savez_compressed(test_file, X=X_test, y=y_test)
    print(f"  Saved data to {train_file} and {test_file}")

    return X.shape[1]

# 6. Run the function for both clients
num_features_c1 = split_and_save_client(df_client_1, "hospital", OUTPUT_NPZ_PATH)
num_features_c2 = split_and_save_client(df_client_2, "factory", OUTPUT_NPZ_PATH)

# --- 7. FINAL, CRITICAL STEP ---
print("\n---------------------------------------------------------")
print("✅✅✅ PREPROCESSING 100% COMPLETE! ✅✅✅")
print("\nYour files are ready in Google Drive:")
print("  - client_hospital_train.npz (For Client 1 Training)")
print("  - client_hospital_test.npz  (For Client 1 *Local* Testing)")
print("  - client_factory_train.npz  (For Client 2 Training)")
print("  - client_factory_test.npz   (For Client 2 *Local* Testing)")
print("  - global_test_set.npz     (For the *Server* to test the final model)")
print("\nIMPORTANT: Go to your 'HERMES_Client/config.py' file and set:")
print(f"NUM_FEATURES = {num_features_c1}")
print("---------------------------------------------------------")

Loading final reduced dataset from: /content/drive/MyDrive/FL_Project/final_reduced_data.npz
Load complete.

Creating Global Test Set (Size: 10.0%) 
Main data for clients: 1456535 samples
Global test set: 161838 samples

Processing Global Test Set...
  Saved Global Test Set to: /content/drive/MyDrive/FL_Project/client_data/global_test_set.npz

Applying Non-IID split to main training data...
Client 1 ('hospital') has 491502 total samples.
  Labels: ['Spoofing' 'Benign']
Client 2 ('factory') has 1310050 total samples.
  Labels: ['DOS' 'DDOS' 'Recon' 'Benign' 'Mirai']

Processing hospital...
  Found 358468 normal samples and 133034 attack samples.
  Split into 393201 local train and 98301 local test samples.
  Saved data to /content/drive/MyDrive/FL_Project/client_data/client_hospital_train.npz and /content/drive/MyDrive/FL_Project/client_data/client_hospital_test.npz

Processing factory...
  Found 358468 normal samples and 951582 attack samples.
  Split into 1048040 local train and 26201
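On the client side, each `.npz` file can be read back with `np.load`, and its feature count checked against the `NUM_FEATURES` value printed by the splitting cell. The minimal sketch below uses a temporary file and random stand-in arrays. The `load_client_split` helper and its checks are assumptions about what client code might do; they are not taken from `HERMES_Client`. `NUM_FEATURES = 30` assumes the `TOP_K_FEATURES = 30` setting above.

```python
import os
import tempfile
import numpy as np

NUM_FEATURES = 30  # assumed to match the value printed by the splitting cell

def load_client_split(path, num_features):
    """Hypothetical helper: load X/y from a client .npz file and sanity-check them."""
    with np.load(path) as data:  # no allow_pickle needed: X and y are numeric arrays
        X, y = data['X'], data['y']
    if X.shape[1] != num_features:
        raise ValueError(f"{path}: expected {num_features} features, got {X.shape[1]}")
    if not set(np.unique(y)) <= {0, 1}:
        raise ValueError(f"{path}: labels are not binary")
    return X, y

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "client_hospital_train.npz")
    rng = np.random.default_rng(0)
    np.savez_compressed(path, X=rng.random((8, NUM_FEATURES)), y=rng.integers(0, 2, size=8))
    X_train, y_train = load_client_split(path, NUM_FEATURES)
    print(X_train.shape, np.bincount(y_train, minlength=2))
```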