In [1]:
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import pickle
import threading
import argparse # For command-line argument parsing
from datetime import datetime

# Imports for machine learning and federated learning components
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# For TensorFlow models
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model, clone_model
from tensorflow.keras.layers import Dense, LSTM, Dropout
from tensorflow.keras.optimizers import Adam # Using Adam directly
from tensorflow.keras.callbacks import EarlyStopping

# --- ACTION REQUIRED: Mount your Google Drive if you plan to use it for persistent storage ---
# from google.colab import drive
# drive.mount('/content/drive')

# --- Define project base directory ---
# You can change this to a path in your Google Drive, e.g., "/content/drive/MyDrive/FL_IDS_Project"
# If left as is, it will create a directory in the current Colab session's temporary storage.
BASE_DIR = "/content/federated_ids_ai_project" # Default path for this federated learning script

if "/content/drive/" in BASE_DIR:
    print(f"✅ Using Google Drive path for BASE_DIR: {BASE_DIR}") # Ensured regular spaces
else:
    print(f"🔔 Using temporary Colab storage for BASE_DIR: {BASE_DIR}") # Ensured regular spaces
    print("   Data in this path will NOT persist across sessions unless you change it to a Drive path.") # Ensured regular spaces

# Define subdirectories for data, models, and federated learning artifacts
DATA_DIR = os.path.join(BASE_DIR, "data")
MODEL_DIR = os.path.join(BASE_DIR, "models")
FL_DIR = os.path.join(BASE_DIR, "federated_outputs")

# Create directories if they don't exist
for directory in [BASE_DIR, DATA_DIR, MODEL_DIR, FL_DIR]:
    os.makedirs(directory, exist_ok=True)

print(f"\nProject base directory: {BASE_DIR}")
print(f"Data directory (for client data if loaded from files): {DATA_DIR}")
print(f"Models directory (for saved FL models): {MODEL_DIR}")
print(f"Federated Learning outputs directory: {FL_DIR}")

print("\n✅ Section 1 (Federated Learning - Setup and Configuration) is ready.")

🔔 Using temporary Colab storage for BASE_DIR: /content/federated_ids_ai_project
   Data in this path will NOT persist across sessions unless you change it to a Drive path.

Project base directory: /content/federated_ids_ai_project
Data directory (for client data if loaded from files): /content/federated_ids_ai_project/data
Models directory (for saved FL models): /content/federated_ids_ai_project/models
Federated Learning outputs directory: /content/federated_ids_ai_project/federated_outputs

✅ Section 1 (Federated Learning - Setup and Configuration) is ready.


In [2]:
#@title 📱 Section 2: Federated Learning -
class FederatedClient:
    """
    Represents a client (edge device) in the federated learning system.
    Each client has its own local data and model.
    """

    def __init__(self, client_id, data=None, model=None):
        self.client_id = client_id
        self.data = data # Can be a DataFrame, or will be loaded by load_data
        self.model = model # Keras model instance, can be assigned later
        self.history = [] # To store training history for this client
        print(f"Client {self.client_id}: Initialized.")

    def load_data(self, data_path=None, partition_identifier=None, total_partitions=1):
        """
        Load data for this client.
        If data_path is provided, load from file and take a specific partition.
        If self.data was already a DataFrame (e.g., pre-partitioned and passed during init), it's used.

        Args:
            data_path (str, optional): Path to the full dataset file (e.g., .csv, .json, .jsonl).
            partition_identifier (int, optional): An identifier for this client's partition (e.g., client_id).
                                                  Used with total_partitions to select a data slice.
                                                  Assumes 0-based indexing for partitions.
            total_partitions (int, optional): Total number of partitions the dataset should be divided into.
        """
        if data_path and os.path.exists(data_path): # Check if data_path is provided and exists
            try:
                print(f"Client {self.client_id}: Attempting to load full dataset from {data_path}...")
                if data_path.endswith('.csv'):
                    full_data_df = pd.read_csv(data_path, low_memory=False)
                elif data_path.endswith('.json') or data_path.endswith('.jsonl'):
                    full_data_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
                else:
                    raise ValueError(f"Unsupported file format for data_path: {data_path}")

                if partition_identifier is not None and total_partitions > 0 and total_partitions <= len(full_data_df):
                    if not (0 <= partition_identifier < total_partitions):
                        raise ValueError(f"Client {self.client_id}: partition_identifier ({partition_identifier}) "
                                         f"must be between 0 and total_partitions-1 ({total_partitions-1}).")

                    num_samples_total = len(full_data_df)
                    # Integer division for roughly equal partitions
                    samples_per_partition = num_samples_total // total_partitions

                    start_idx = partition_identifier * samples_per_partition
                    # For the last partition, assign all remaining samples to ensure all data is used
                    end_idx = (partition_identifier + 1) * samples_per_partition if partition_identifier < total_partitions - 1 else num_samples_total

                    self.data = full_data_df.iloc[start_idx:end_idx].copy()
                    print(f"Client {self.client_id}: Loaded partition {partition_identifier + 1}/{total_partitions} "
                          f"(indices {start_idx}-{end_idx-1}, {len(self.data)} samples) from {data_path}")
                else: # Use the whole dataset if no valid partitioning info
                    self.data = full_data_df
                    print(f"Client {self.client_id}: Loaded {len(self.data)} samples (full dataset) from {data_path}")
                return True

            except Exception as e:
                print(f"Client {self.client_id}: Error loading data from path '{data_path}': {e}")
                self.data = None
                return False

        elif isinstance(self.data, pd.DataFrame): # Data was passed during __init__
            print(f"Client {self.client_id}: Using {len(self.data)} samples from pre-loaded DataFrame.")
            return True
        elif self.data is not None: # Data was passed but not a DataFrame
             try:
                self.data = pd.DataFrame(self.data) # Try to convert
                print(f"Client {self.client_id}: Converted pre-loaded data to DataFrame, {len(self.data)} samples.")
                return True
             except Exception as e:
                print(f"Client {self.client_id}: Could not convert pre-loaded data of type {type(self.data)} to DataFrame: {e}")
                self.data = None
                return False
        else: # No data_path and self.data is None
            print(f"Client {self.client_id}: No data path provided and no pre-loaded data.")
            self.data = None
            return False

    def preprocess_data(self, feature_columns=None, target_column=None, test_size=0.2, reshape_for_lstm=False):
        """
        Preprocess the client's local data.
        This includes feature selection, label encoding, scaling, and splitting.
        Returns: (X_train_scaled, X_test_scaled, y_train, y_test, preproc_objects) or None
        """
        if self.data is None or self.data.empty:
            print(f"Client {self.client_id}: No data loaded for preprocessing.")
            return None

        try:
            df_processed = self.data.copy()
            df_processed.replace([np.inf, -np.inf], np.nan, inplace=True)
            df_processed.dropna(inplace=True)

            if df_processed.empty:
                print(f"Client {self.client_id}: Data became empty after handling NaNs during preprocessing.")
                return None

            if target_column is None:
                target_column = df_processed.columns[-1]
                print(f"Client {self.client_id}: Assuming last column '{target_column}' as target.")

            if target_column not in df_processed.columns:
                print(f"Client {self.client_id}: Target column '{target_column}' not found.")
                return None

            y = df_processed[target_column]

            if feature_columns is None:
                X = df_processed.drop(columns=[target_column])
                print(f"Client {self.client_id}: Using all columns except '{target_column}' as features.")
            else:
                missing_cols = [col for col in feature_columns if col not in df_processed.columns]
                if missing_cols:
                    print(f"Client {self.client_id}: Missing specified feature columns: {missing_cols}")
                    return None
                X = df_processed[feature_columns]

            used_feature_columns = X.columns.tolist()

            # Ensure X contains only numeric data for scaling, or handle non-numeric appropriately
            X_numeric = X.select_dtypes(include=np.number)
            if X_numeric.shape[1] < X.shape[1]:
                non_numeric_cols = X.select_dtypes(exclude=np.number).columns.tolist()
                print(f"Client {self.client_id}: Warning - Dropping non-numeric columns from features: {non_numeric_cols}. "
                      "These should be encoded if they are intended for use.")
            X = X_numeric # Proceed with only numeric features

            if X.empty:
                print(f"Client {self.client_id}: No numeric features available after selection/filtering.")
                return None

            label_encoder = LabelEncoder()
            y_encoded = label_encoder.fit_transform(y)
            num_classes = len(label_encoder.classes_)

            # Stratification requires at least 2 samples per class in the smallest split (test set)
            # or generally at least n_splits samples per class.
            min_samples_per_class_for_stratify = 2 # A common minimum for train_test_split with test_size
            class_counts = pd.Series(y_encoded).value_counts()

            can_stratify = num_classes >= 2 and all(count >= min_samples_per_class_for_stratify for count in class_counts)

            if can_stratify:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id, stratify=y_encoded
                )
            else:
                if num_classes < 2:
                    print(f"Client {self.client_id}: Only {num_classes} unique class(es) in target. Cannot stratify or train effectively.")
                    # If only one class, training is not meaningful
                    return None
                print(f"Client {self.client_id}: Cannot stratify due to too few samples in some classes. Splitting without stratification.")
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id
                )

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            if reshape_for_lstm:
                X_train_scaled = X_train_scaled.reshape(X_train_scaled.shape[0], 1, X_train_scaled.shape[1])
                X_test_scaled = X_test_scaled.reshape(X_test_scaled.shape[0], 1, X_test_scaled.shape[1])

            preproc_objects = {
                'scaler': scaler,
                'label_encoder': label_encoder,
                'used_feature_columns': used_feature_columns,
                'target_column_name': target_column,
                'num_classes': num_classes # Number of unique classes found by LabelEncoder
            }

            print(f"Client {self.client_id}: Data preprocessed. Train shape: {X_train_scaled.shape}, Test shape: {X_test_scaled.shape}")
            return X_train_scaled, X_test_scaled, y_train, y_test, preproc_objects

        except Exception as e:
            print(f"Client {self.client_id}: Error during data preprocessing: {e}")
            import traceback
            traceback.print_exc()
            return None

# --- Test Block for FederatedClient Data Handling (Optional) ---
# This test runs only if the cell is executed directly in Colab
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedClient Data Handling ---")

    # Create dummy data for client
    dummy_data_for_client = pd.DataFrame({
        'feature1': np.random.rand(100),
        'feature2': np.random.rand(100),
        'feature3_cat': np.random.choice(['A', 'B', 'C'], 100), # Categorical feature
        'feature4_num': np.random.randint(0, 5, 100),
        'label': np.random.choice(['Normal', 'Attack_Type1', 'Attack_Type2'], 100) # Multi-class target
    })

    # DATA_DIR should be defined from Section 1
    if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR):
        dummy_csv_path = os.path.join(DATA_DIR, "client_dummy_data.csv")
        dummy_data_for_client.to_csv(dummy_csv_path, index=False)
        print(f"Dummy data saved to {dummy_csv_path}")

        # Test Case 1: Load from path with partitioning
        print("\n-- Test Case 1: Load from path with partitioning --")
        client1 = FederatedClient(client_id=1)
        # Load 1st partition out of 2
        client1.load_data(data_path=dummy_csv_path, partition_identifier=0, total_partitions=2)
        if client1.data is not None:
            print(f"Client 1 data shape after load & partition: {client1.data.shape}")
            # Preprocess, assuming last col 'label' is target, and we want LSTM shape
            # Explicitly select numeric features for this dummy data, excluding 'feature3_cat' for now
            client1_preproc_result = client1.preprocess_data(
                feature_columns=['feature1', 'feature2', 'feature4_num'],
                target_column='label',
                reshape_for_lstm=True
            )
            if client1_preproc_result:
                 print(f"Client 1 preprocessed train data shape (LSTM): {client1_preproc_result[0].shape}")
                 print(f"Client 1 preprocessed objects: {client1_preproc_result[4].keys()}")
                 print(f"Client 1 num_classes from preproc: {client1_preproc_result[4]['num_classes']}")


        # Test Case 2: Load with pre-loaded data (full dummy data)
        print("\n-- Test Case 2: Load with pre-loaded data --")
        client2 = FederatedClient(client_id=2, data=dummy_data_for_client.copy()) # Pass a copy
        client2.load_data() # Should use the pre-loaded data, already a DataFrame
        if client2.data is not None:
            print(f"Client 2 data shape from pre-loaded: {client2.data.shape}")
            # Preprocess for MLP, using default feature selection (all numeric but target)
            client2_preproc_result = client2.preprocess_data(target_column='label', reshape_for_lstm=False)
            if client2_preproc_result:
                 print(f"Client 2 preprocessed train data shape (MLP): {client2_preproc_result[0].shape}")
                 print(f"Client 2 preprocessed objects keys: {list(client2_preproc_result[4].keys())}")
                 print(f"Client 2 target classes by LabelEncoder: {client2_preproc_result[4]['label_encoder'].classes_}")
                 print(f"Client 2 num_classes from preproc: {client2_preproc_result[4]['num_classes']}")

    else:
        print("⚠️ Skipping FederatedClient data handling test as DATA_DIR is not defined " \
              "or accessible (Run Section 1 of Federated Learning script).")
    print("\n--- End of FederatedClient Data Handling Test ---")

print("\n✅ Section 2 (FederatedClient Class - Part 1: Init & Data) is ready.")


--- Testing FederatedClient Data Handling ---
Dummy data saved to /content/federated_ids_ai_project/data/client_dummy_data.csv

-- Test Case 1: Load from path with partitioning --
Client 1: Initialized.
Client 1: Attempting to load full dataset from /content/federated_ids_ai_project/data/client_dummy_data.csv...
Client 1: Loaded partition 1/2 (indices 0-49, 50 samples) from /content/federated_ids_ai_project/data/client_dummy_data.csv
Client 1 data shape after load & partition: (50, 5)
Client 1: Data preprocessed. Train shape: (40, 1, 3), Test shape: (10, 1, 3)
Client 1 preprocessed train data shape (LSTM): (40, 1, 3)
Client 1 preprocessed objects: dict_keys(['scaler', 'label_encoder', 'used_feature_columns', 'target_column_name', 'num_classes'])
Client 1 num_classes from preproc: 3

-- Test Case 2: Load with pre-loaded data --
Client 2: Initialized.
Client 2: Using 100 samples from pre-loaded DataFrame.
Client 2 data shape from pre-loaded: (100, 5)
Client 2: Using all columns except '

In [3]:

#@title 🧠 Section 3: Federated Learning - FederatedClient Class (Part 2: Model Building & Training)
# Imports from Section 1 should still be in effect.
# Ensure TensorFlow/Keras components are available.
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import tensorflow as tf
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import Dense, LSTM, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from datetime import datetime

class FederatedClient:
    """
    Represents a client (edge device) in the federated learning system.
    Each client has its own local data and model.
    (Includes methods from Part 1 and adds new methods from Part 2)
    """

    def __init__(self, client_id, data=None, model=None):
        self.client_id = client_id
        self.data = data
        self.model = model
        self.history = []
        # print(f"Client {self.client_id}: Initialized.") # Keep this less verbose for multiple clients

    def load_data(self, data_path=None, partition_identifier=None, total_partitions=1):
        if data_path and os.path.exists(data_path):
            try:
                # print(f"Client {self.client_id}: Attempting to load full dataset from {data_path}...")
                if data_path.endswith('.csv'): full_data_df = pd.read_csv(data_path, low_memory=False)
                elif data_path.endswith(('.json', '.jsonl')): full_data_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
                else: raise ValueError(f"Unsupported file format: {data_path}")

                if partition_identifier is not None and total_partitions > 0 and total_partitions <= len(full_data_df):
                    if not (0 <= partition_identifier < total_partitions): raise ValueError("partition_identifier out of range.")
                    num_samples_total = len(full_data_df)
                    samples_per_partition = num_samples_total // total_partitions
                    start_idx = partition_identifier * samples_per_partition
                    end_idx = (partition_identifier + 1) * samples_per_partition if partition_identifier < total_partitions - 1 else num_samples_total
                    self.data = full_data_df.iloc[start_idx:end_idx].copy()
                    # print(f"Client {self.client_id}: Loaded partition {partition_identifier + 1}/{total_partitions} ({len(self.data)} samples)")
                else:
                    self.data = full_data_df
                    # print(f"Client {self.client_id}: Loaded {len(self.data)} samples (full dataset).")
                return True
            except Exception as e:
                print(f"Client {self.client_id}: Error loading data from '{data_path}': {e}"); self.data = None; return False

In [4]:

#@title 🧠 Section 3: Federated Learning - FederatedClient Class (Part 2: Model Building & Training)
# Imports from Section 1 should still be in effect.
# Class FederatedClient from Section 2 will be extended here.

class FederatedClient:
    """
    Represents a client (edge device) in the federated learning system.
    Each client has its own local data and model.
    (Includes methods from Part 1 and adds new methods from Part 2)
    """

    def __init__(self, client_id, data=None, model=None):
        self.client_id = client_id
        self.data = data
        self.model = model
        self.history = []
        # print(f"Client {self.client_id}: Initialized.") # Kept less verbose

    def load_data(self, data_path=None, partition_identifier=None, total_partitions=1):
        if data_path and os.path.exists(data_path):
            try:
                # print(f"Client {self.client_id}: Loading from {data_path}...") # Verbose
                if data_path.endswith('.csv'): full_data_df = pd.read_csv(data_path, low_memory=False)
                elif data_path.endswith(('.json', '.jsonl')): full_data_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
                else: raise ValueError(f"Unsupported file format: {data_path}")

                if partition_identifier is not None and total_partitions > 0 and total_partitions <= len(full_data_df):
                    if not (0 <= partition_identifier < total_partitions): raise ValueError("partition_identifier out of range.")
                    num_samples_total = len(full_data_df)
                    samples_per_partition = num_samples_total // total_partitions
                    start_idx = partition_identifier * samples_per_partition
                    end_idx = (partition_identifier + 1) * samples_per_partition if partition_identifier < total_partitions - 1 else num_samples_total
                    self.data = full_data_df.iloc[start_idx:end_idx].copy()
                    # print(f"Client {self.client_id}: Loaded partition {partition_identifier + 1}/{total_partitions} ({len(self.data)} samples)")
                else:
                    self.data = full_data_df
                    # print(f"Client {self.client_id}: Loaded {len(self.data)} samples (full dataset).")
                return True
            except Exception as e:
                print(f"Client {self.client_id}: Error loading data from '{data_path}': {e}"); self.data = None; return False
        elif isinstance(self.data, pd.DataFrame): return True
        elif self.data is not None:
             try: self.data = pd.DataFrame(self.data); return True
             except Exception as e: print(f"Client {self.client_id}: Could not convert pre-loaded data: {e}"); self.data=None; return False
        else: self.data = None; return False

    def preprocess_data(self, feature_columns=None, target_column=None, test_size=0.2, reshape_for_lstm=False):
        if self.data is None or self.data.empty: return None
        try:
            df_processed = self.data.copy()
            df_processed.replace([np.inf, -np.inf], np.nan, inplace=True)
            df_processed.dropna(inplace=True)
            if df_processed.empty: return None

            if target_column is None: target_column = df_processed.columns[-1]
            if target_column not in df_processed.columns: return None
            y = df_processed[target_column]

            if feature_columns is None: X = df_processed.drop(columns=[target_column])
            else:
                missing_cols = [col for col in feature_columns if col not in df_processed.columns]
                if missing_cols: print(f"Client {self.client_id}: Missing features: {missing_cols}"); return None
                X = df_processed[feature_columns]

            used_feature_columns = X.columns.tolist()
            X_numeric = X.select_dtypes(include=np.number)
            if X_numeric.shape[1] < X.shape[1]: pass # Non-numeric will be dropped
            X = X_numeric
            if X.empty: return None

            label_encoder = LabelEncoder()
            y_encoded = label_encoder.fit_transform(y)
            num_classes = len(label_encoder.classes_)

            min_samples_per_class_for_stratify = 2
            class_counts = pd.Series(y_encoded).value_counts()
            can_stratify = num_classes >= 2 and all(count >= min_samples_per_class_for_stratify for count in class_counts)

            stratify_option = y_encoded if can_stratify else None
            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id, stratify=stratify_option
                )
            except ValueError: # Fallback if stratification fails for any reason
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id
                )

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            if reshape_for_lstm:
                X_train_scaled = X_train_scaled.reshape(X_train_scaled.shape[0], 1, X_train_scaled.shape[1])
                X_test_scaled = X_test_scaled.reshape(X_test_scaled.shape[0], 1, X_test_scaled.shape[1])

            preproc_objects = {
                'scaler': scaler, 'label_encoder': label_encoder,
                'used_feature_columns': used_feature_columns,
                'target_column_name': target_column, 'num_classes': num_classes
            }
            # print(f"Client {self.client_id}: Data preprocessed. Train: {X_train_scaled.shape}, Test: {X_test_scaled.shape}")
            return X_train_scaled, X_test_scaled, y_train, y_test, preproc_objects
        except Exception as e:
            print(f"Client {self.client_id}: Error preprocessing: {e}"); import traceback; traceback.print_exc(); return None

    # --- New methods for Section 3 ---
    def build_lstm_model(self, input_shape, num_classes):
        """Build an LSTM model for sequence classification."""
        model = Sequential([
            LSTM(64, input_shape=input_shape, return_sequences=True),
            Dropout(0.2),
            LSTM(32), # return_sequences=False by default for the last LSTM layer before Dense
            Dropout(0.2),
            Dense(16, activation='relu'), # An intermediate dense layer
            Dense(num_classes, activation='softmax') # Softmax for multi-class classification
        ])
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy', # Use for integer labels
            metrics=['accuracy']
        )
        return model

    def build_mlp_model(self, input_shape, num_classes):
        """Build a simple MLP model for classification."""
        model = Sequential([
            Dense(128, activation='relu', input_shape=input_shape),
            Dropout(0.3),
            Dense(64, activation='relu'),
            Dropout(0.2),
            Dense(32, activation='relu'),
            Dense(num_classes, activation='softmax') # Softmax for multi-class
        ])
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy', # Use for integer labels
            metrics=['accuracy']
        )
        return model

    def train_local_model(self, model_type='mlp', epochs=10, batch_size=32, verbose=0, feature_columns=None, target_column=None):
        """Train the local model on the client's data."""
        if self.data is None:
            print(f"Client {self.client_id}: No data loaded for training.")
            return None

        preproc_result = self.preprocess_data(
            feature_columns=feature_columns,
            target_column=target_column,
            reshape_for_lstm=(model_type.lower() == 'lstm')
        )
        if preproc_result is None:
            print(f"Client {self.client_id}: Preprocessing failed for training.")
            return None

        X_train, X_test, y_train, y_test, preproc_objects = preproc_result
        num_classes = preproc_objects['num_classes']

        # Ensure there's data to train on
        if X_train.shape[0] == 0:
            print(f"Client {self.client_id}: No training data after preprocessing (X_train is empty).")
            return None

        if model_type.lower() == 'lstm':
            model_input_shape = (X_train.shape[1], X_train.shape[2]) # Should be (1, num_features)
            self.model = self.build_lstm_model(model_input_shape, num_classes)
        else: # MLP
            model_input_shape = (X_train.shape[1],) # (num_features,)
            self.model = self.build_mlp_model(model_input_shape, num_classes)

        # print(f"Client {self.client_id}: Training {model_type.upper()} model. Input shape: {model_input_shape}, Classes: {num_classes}.")

        early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=0)

        history_obj = self.model.fit(
            X_train, y_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, y_test),
            callbacks=[early_stopping],
            verbose=verbose # Controlled by argument
        )

        loss, accuracy = self.model.evaluate(X_test, y_test, verbose=0)

        train_result_summary = {
            'client_id': self.client_id, 'model_type': model_type,
            'epochs_run': len(history_obj.history['loss']),
            'batch_size': batch_size,
            'history': {k: [round(float(val), 4) for val in v] for k, v in history_obj.history.items()}, # Ensure float for JSON
            'test_loss': float(loss), 'test_accuracy': float(accuracy),
            'timestamp': datetime.now().isoformat()
        }
        self.history.append(train_result_summary)

        print(f"Client {self.client_id}: Training completed. Test accuracy: {accuracy:.4f}")
        return {
            'model_weights': self.model.get_weights(),
            'num_samples': X_train.shape[0] + X_test.shape[0], # Total samples processed by client
            'preprocessing_details': preproc_objects,
            'training_summary': train_result_summary
        }

# --- Test Block for FederatedClient Model Training (Optional) ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedClient Model Training ---")
    if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR):
        dummy_csv_path = os.path.join(DATA_DIR, "client_dummy_data.csv")
        if not os.path.exists(dummy_csv_path):
            dummy_data_for_client = pd.DataFrame({
                'feature1': np.random.rand(100), 'feature2': np.random.rand(100),
                'feature3_cat': np.random.choice(['A', 'B', 'C'], 100), # Will be dropped by current preprocess
                'feature4_num': np.random.randint(0, 5, 100),
                'label': np.random.choice(['Normal', 'Attack_Type1', 'Attack_Type2'], 100)
            })
            dummy_data_for_client.to_csv(dummy_csv_path, index=False)
            print(f"Re-created dummy data at {dummy_csv_path} for test.")

        client3 = FederatedClient(client_id=3) # New client for this test
        # Load full dummy data for client3
        loaded = client3.load_data(data_path=dummy_csv_path, partition_identifier=0, total_partitions=1)

        if loaded and client3.data is not None:
            print(f"\n-- Training MLP model on Client 3 (Epochs=3, Verbose=1) --")
            numeric_features_for_dummy = ['feature1', 'feature2', 'feature4_num']
            mlp_train_results = client3.train_local_model(
                model_type='mlp', epochs=3, verbose=1,
                feature_columns=numeric_features_for_dummy, target_column='label'
            )
            if mlp_train_results:
                print(f"Client 3 MLP training summary: Test Accuracy {mlp_train_results['training_summary']['test_accuracy']:.4f}")

            client3.model = None # Reset model to ensure LSTM builds fresh
            print(f"\n-- Training LSTM model on Client 3 (Epochs=3, Verbose=1) --")
            lstm_train_results = client3.train_local_model(
                model_type='lstm', epochs=3, verbose=1,
                feature_columns=numeric_features_for_dummy, target_column='label'
            )
            if lstm_train_results:
                print(f"Client 3 LSTM training summary: Test Accuracy {lstm_train_results['training_summary']['test_accuracy']:.4f}")
        else:
            print("Client 3 could not load data for model training test.")
    else:
        print("⚠️ Skipping FederatedClient model training test as DATA_DIR not found (Run Section 1).")
    print("\n--- End of FederatedClient Model Training Test ---")

print("\n✅ Section 3 (FederatedClient Class - Part 2: Model Building & Training) is ready.")


--- Testing FederatedClient Model Training ---

-- Training MLP model on Client 3 (Epochs=3, Verbose=1) --


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 384ms/step - accuracy: 0.3500 - loss: 1.1244 - val_accuracy: 0.2500 - val_loss: 1.0937
Epoch 2/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 191ms/step - accuracy: 0.3461 - loss: 1.1167 - val_accuracy: 0.1500 - val_loss: 1.0898
Epoch 3/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 167ms/step - accuracy: 0.3805 - loss: 1.0851 - val_accuracy: 0.2000 - val_loss: 1.0874
Client 3: Training completed. Test accuracy: 0.2000
Client 3 MLP training summary: Test Accuracy 0.2000

-- Training LSTM model on Client 3 (Epochs=3, Verbose=1) --


  super().__init__(**kwargs)


Epoch 1/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 292ms/step - accuracy: 0.3758 - loss: 1.0980 - val_accuracy: 0.3500 - val_loss: 1.0974
Epoch 2/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.4711 - loss: 1.0950 - val_accuracy: 0.3000 - val_loss: 1.0968
Epoch 3/3
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.4453 - loss: 1.0939 - val_accuracy: 0.4500 - val_loss: 1.0963
Client 3: Training completed. Test accuracy: 0.4500
Client 3 LSTM training summary: Test Accuracy 0.4500

--- End of FederatedClient Model Training Test ---

✅ Section 3 (FederatedClient Class - Part 2: Model Building & Training) is ready.


In [5]:
#@title 💾 Section 4: Federated Learning - FederatedClient Class (Part 3: Model Management & Evaluation)


class FederatedClient:
    """
    Represents a client (edge device) in the federated learning system.
    Each client has its own local data and model.
    (Includes methods from Part 1 & 2, and adds new methods from Part 3)
    """

    def __init__(self, client_id, data=None, model=None):
        self.client_id = client_id
        self.data = data
        self.model = model
        self.history = []
        # print(f"Client {self.client_id}: Initialized.") # Verbosity controlled

    def load_data(self, data_path=None, partition_identifier=None, total_partitions=1):
        # This method is from Section 2, included for completeness of the class definition
        if data_path and os.path.exists(data_path):
            try:
                if data_path.endswith('.csv'): full_data_df = pd.read_csv(data_path, low_memory=False)
                elif data_path.endswith(('.json', '.jsonl')): full_data_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
                else: raise ValueError(f"Unsupported file format: {data_path}")

                if partition_identifier is not None and total_partitions > 0 and total_partitions <= len(full_data_df):
                    if not (0 <= partition_identifier < total_partitions): raise ValueError("partition_identifier out of range.")
                    num_samples_total = len(full_data_df)
                    samples_per_partition = num_samples_total // total_partitions
                    start_idx = partition_identifier * samples_per_partition
                    end_idx = (partition_identifier + 1) * samples_per_partition if partition_identifier < total_partitions - 1 else num_samples_total
                    self.data = full_data_df.iloc[start_idx:end_idx].copy()
                else:
                    self.data = full_data_df
                return True
            except Exception as e:
                print(f"Client {self.client_id}: Error loading data from '{data_path}': {e}"); self.data = None; return False
        elif isinstance(self.data, pd.DataFrame): return True
        elif self.data is not None:
             try: self.data = pd.DataFrame(self.data); return True
             except Exception as e: print(f"Client {self.client_id}: Could not convert pre-loaded data: {e}"); self.data=None; return False
        else: self.data = None; return False

    def preprocess_data(self, feature_columns=None, target_column=None, test_size=0.2, reshape_for_lstm=False):
        # This method is from Section 2, included for completeness
        if self.data is None or self.data.empty: return None
        try:
            df_processed = self.data.copy()
            df_processed.replace([np.inf, -np.inf], np.nan, inplace=True)
            df_processed.dropna(inplace=True)
            if df_processed.empty: return None

            if target_column is None: target_column = df_processed.columns[-1]
            if target_column not in df_processed.columns: return None
            y = df_processed[target_column]

            if feature_columns is None: X = df_processed.drop(columns=[target_column])
            else:
                missing_cols = [col for col in feature_columns if col not in df_processed.columns]
                if missing_cols: print(f"Client {self.client_id}: Missing features: {missing_cols}"); return None
                X = df_processed[feature_columns]

            used_feature_columns = X.columns.tolist()
            X_numeric = X.select_dtypes(include=np.number)
            if X_numeric.shape[1] < X.shape[1]: pass
            X = X_numeric
            if X.empty: return None

            label_encoder = LabelEncoder()
            y_encoded = label_encoder.fit_transform(y)
            num_classes = len(label_encoder.classes_)

            min_samples_per_class_for_stratify = 2
            class_counts = pd.Series(y_encoded).value_counts()
            can_stratify = num_classes >= 2 and all(count >= min_samples_per_class_for_stratify for count in class_counts if count >0)

            stratify_option = y_encoded if can_stratify else None
            if len(X) < 2 :
                 print(f"Client {self.client_id}: Not enough samples ({len(X)}) to perform train/test split."); return None

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id, stratify=stratify_option
                )
            except ValueError:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=self.client_id
                )

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            if len(X_test) > 0:
                 X_test_scaled = scaler.transform(X_test)
            else:
                 X_test_scaled = np.array([])

            if reshape_for_lstm:
                X_train_scaled = X_train_scaled.reshape(X_train_scaled.shape[0], 1, X_train_scaled.shape[1])
                if len(X_test_scaled) > 0:
                     X_test_scaled = X_test_scaled.reshape(X_test_scaled.shape[0], 1, X_test_scaled.shape[1])

            preproc_objects = {'scaler': scaler, 'label_encoder': label_encoder, 'used_feature_columns': used_feature_columns, 'target_column_name': target_column, 'num_classes': num_classes}
            return X_train_scaled, X_test_scaled, y_train, y_test, preproc_objects
        except Exception as e:
            print(f"Client {self.client_id}: Error preprocessing: {e}"); import traceback; traceback.print_exc(); return None

    def build_lstm_model(self, input_shape, num_classes):
        # This method is from Section 3, included for completeness
        model = Sequential([
            LSTM(64, input_shape=input_shape, return_sequences=True), Dropout(0.2),
            LSTM(32), Dropout(0.2), Dense(16, activation='relu'),
            Dense(num_classes, activation='softmax')])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model

    def build_mlp_model(self, input_shape, num_classes):
        # This method is from Section 3, included for completeness
        model = Sequential([
            Dense(128, activation='relu', input_shape=input_shape), Dropout(0.3),
            Dense(64, activation='relu'), Dropout(0.2), Dense(32, activation='relu'),
            Dense(num_classes, activation='softmax')])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model

    def train_local_model(self, model_type='mlp', epochs=10, batch_size=32, verbose=0, feature_columns=None, target_column=None):
        # This method is from Section 3, included for completeness
        if self.data is None: return None
        preproc_result = self.preprocess_data(feature_columns, target_column, reshape_for_lstm=(model_type.lower() == 'lstm'))
        if preproc_result is None: return None
        X_train, X_test, y_train, y_test, preproc_objects = preproc_result
        num_classes = preproc_objects['num_classes']
        if X_train.shape[0] == 0: return None

        if model_type.lower() == 'lstm':
            model_input_shape = (X_train.shape[1], X_train.shape[2])
            self.model = self.build_lstm_model(model_input_shape, num_classes)
        else:
            model_input_shape = (X_train.shape[1],)
            self.model = self.build_mlp_model(model_input_shape, num_classes)

        validation_data_tuple = (X_test, y_test) if len(X_test) > 0 else None
        early_stopping = EarlyStopping(monitor='val_loss' if validation_data_tuple else 'loss', patience=5, restore_best_weights=True, verbose=0)
        history_obj = self.model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=validation_data_tuple, callbacks=[early_stopping], verbose=verbose)

        test_loss, test_accuracy = (0.0,0.0)
        if validation_data_tuple :
             test_loss, test_accuracy = self.model.evaluate(X_test, y_test, verbose=0)

        train_summary = {'client_id':self.client_id, 'model_type':model_type, 'epochs_run':len(history_obj.history['loss']), 'batch_size':batch_size, 'history':{k:[round(float(v),4) for v in val_list] for k,val_list in history_obj.history.items()}, 'test_loss':float(test_loss), 'test_accuracy':float(test_accuracy), 'timestamp':datetime.now().isoformat()}
        self.history.append(train_summary)
        print(f"Client {self.client_id}: Training completed. Test accuracy: {test_accuracy:.4f}")
        return {'model_weights': self.model.get_weights(), 'num_samples': X_train.shape[0] + (X_test.shape[0] if len(X_test)>0 else 0) , 'preprocessing_details': preproc_objects, 'training_summary': train_summary}

    # --- New methods for Section 4 ---
    def update_model(self, model_weights):
        """Update the local model with new weights."""
        if self.model is None:
            print(f"Client {self.client_id}: No local model initialized to update.")
            return False
        try:
            self.model.set_weights(model_weights)
            return True
        except Exception as e:
            print(f"Client {self.client_id}: Error updating model weights: {e}")
            return False

    def get_model_weights(self):
        """Get the weights of the local model."""
        if self.model is None:
            print(f"Client {self.client_id}: No model available to get weights from.")
            return None
        return self.model.get_weights()

    def evaluate_model(self, X_test_external=None, y_test_external=None, preproc_objects_external=None):
        """Evaluate the local model."""
        if self.model is None:
            print(f"Client {self.client_id}: No model to evaluate.")
            return None

        X_test_eval, y_test_eval = None, None

        if X_test_external is not None and y_test_external is not None:
            X_test_eval, y_test_eval = X_test_external, y_test_external
            if self.model.layers[0].__class__.__name__ == 'LSTM' and len(X_test_eval.shape) == 2: # Check if model is LSTM by first layer type
                 X_test_eval = X_test_eval.reshape(X_test_eval.shape[0], 1, X_test_eval.shape[1])
        else:
            if not self.history:
                print(f"Client {self.client_id}: No training history and no external test data provided.")
                return None
            last_train_summary = self.history[-1]
            print(f"Client {self.client_id}: Reporting test accuracy from last local training: {last_train_summary['test_accuracy']:.4f}")
            return {
                'client_id': self.client_id, 'loss': last_train_summary['test_loss'],
                'accuracy': last_train_summary['test_accuracy'], 'timestamp': datetime.now().isoformat(),
                'note': 'Metrics from last local training validation split.'}

        if X_test_eval is None or y_test_eval is None or (isinstance(X_test_eval, np.ndarray) and X_test_eval.size == 0):
            print(f"Client {self.client_id}: No test data available for evaluation.")
            return None

        loss, accuracy = self.model.evaluate(X_test_eval, y_test_eval, verbose=0)
        y_pred_probs = self.model.predict(X_test_eval, verbose=0)
        y_pred_classes = np.argmax(y_pred_probs, axis=1)

        report_dict = {}; conf_matrix_list = []
        try:
            unique_true_labels = np.unique(y_test_eval)
            unique_pred_labels = np.unique(y_pred_classes)
            all_labels = np.union1d(unique_true_labels, unique_pred_labels) # Get all unique labels present in either

            report_dict = classification_report(y_test_eval, y_pred_classes, labels=all_labels if len(all_labels)>0 else None, output_dict=True, zero_division=0)
            conf_matrix_list = confusion_matrix(y_test_eval, y_pred_classes, labels=all_labels if len(all_labels)>0 else None).tolist()
        except Exception as e:
            print(f"Client {self.client_id}: Could not generate full classification report/confusion matrix: {e}")

        eval_result = {
            'client_id': self.client_id, 'loss': float(loss), 'accuracy': float(accuracy),
            'classification_report': report_dict, 'confusion_matrix': conf_matrix_list,
            'timestamp': datetime.now().isoformat()}
        return eval_result

    def save_model(self, path=None):
        """Save the local model to a file."""
        if self.model is None: print(f"Client {self.client_id}: No model to save."); return None
        if path is None:
            path = os.path.join(FL_DIR if 'FL_DIR' in globals() else ".", f"client_{self.client_id}_model.h5")
        try:
            self.model.save(path)
            print(f"Client {self.client_id}: Model saved to {path}")
            return path
        except Exception as e: print(f"Client {self.client_id}: Error saving model: {e}"); return None

    def load_model(self, path):
        """Load a model from a file. Overwrites existing self.model."""
        try:
            self.model = tf.keras.models.load_model(path)
            print(f"Client {self.client_id}: Model loaded from {path}")
            return True
        except Exception as e: print(f"Client {self.client_id}: Error loading model from {path}: {e}"); return False

# --- Test Block for FederatedClient Model Management (Optional) ---
# Corrected Indentation and typo for this block
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedClient Model Management ---")
    if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR) and \
       'FL_DIR' in globals() and os.path.exists(FL_DIR):

        dummy_csv_path = os.path.join(DATA_DIR, "client_dummy_data.csv")
        if not os.path.exists(dummy_csv_path):
            pd.DataFrame({
                'feature1': np.random.rand(40),
                'feature2': np.random.rand(40),
                'feature4_num': np.random.randint(0,5,40),
                'label': np.random.choice(['Normal', 'Attack_Type1'], 40)
            }).to_csv(dummy_csv_path, index=False)
            print(f"Created/Re-created dummy data for test: {dummy_csv_path}")

        client4 = FederatedClient(client_id=4)
        loaded = client4.load_data(data_path=dummy_csv_path)

        if loaded and client4.data is not None:
            print("\n-- Training a model on Client 4 to test save/load/evaluate --")
            dummy_numeric_features = ['feature1', 'feature2', 'feature4_num']
            train_results_c4 = client4.train_local_model(
                model_type='mlp', epochs=1, verbose=0,
                feature_columns=dummy_numeric_features, target_column='label'
            )

            if train_results_c4 and client4.model:
                print(f"Client 4 trained. Accuracy on its test split: {train_results_c4['training_summary']['test_accuracy']:.4f}")

                saved_path = client4.save_model()
                if saved_path and os.path.exists(saved_path):
                    client_new_load = FederatedClient(client_id=40)
                    loaded_successfully = client_new_load.load_model(saved_path)

                    if loaded_successfully and client_new_load.model:
                        print("Model loaded into new client instance (client40) successfully.")

                        # Re-preprocess data for client4 to get its X_test, y_test for evaluation by client40
                        c4_preproc_result = client4.preprocess_data(
                            feature_columns=dummy_numeric_features, target_column='label', reshape_for_lstm=False # Assuming MLP for this test
                        )
                        if c4_preproc_result:
                            _, c4_X_test, _, c4_y_test, _ = c4_preproc_result
                            if c4_X_test is not None and c4_y_test is not None and len(c4_X_test) > 0 :
                                eval_res_loaded_model = client_new_load.evaluate_model(X_test_external=c4_X_test, y_test_external=c4_y_test)
                                if eval_res_loaded_model:
                                    print(f"Evaluation of loaded model on Client 40 (using Client 4's test data): Accuracy {eval_res_loaded_model['accuracy']:.4f}")
                            else:
                                print("Could not get Client 4's test data for evaluating loaded model on Client 40.")
                        else:
                            print("Failed to preprocess data for Client 4 to get test set for evaluation.")


                weights_c4 = client4.get_model_weights()
                if weights_c4:
                    client_for_weights_test = FederatedClient(client_id=41)
                    num_feats_c4 = len(train_results_c4['preprocessing_details']['used_feature_columns'])
                    num_classes_c4 = train_results_c4['preprocessing_details']['num_classes']
                    client_for_weights_test.model = client_for_weights_test.build_mlp_model(input_shape=(num_feats_c4,), num_classes=num_classes_c4)

                    updated = client_for_weights_test.update_model(weights_c4)
                    print(f"Weights updated on client_for_weights_test (client 41): {updated}")
                    if updated and c4_preproc_result: # Use already preprocessed data if available
                         _, c4_X_test_for_weights, _, c4_y_test_for_weights, _ = c4_preproc_result
                         if c4_X_test_for_weights is not None and c4_y_test_for_weights is not None and len(c4_X_test_for_weights) > 0:
                              eval_res_weights = client_for_weights_test.evaluate_model(X_test_external=c4_X_test_for_weights, y_test_external=c4_y_test_for_weights) # Corrected variable name
                              if eval_res_weights:
                                   print(f"Evaluation of weights-updated model on Client 41: Accuracy {eval_res_weights['accuracy']:.4f}")
                         else:
                              print("Could not get Client 4's test data for evaluating weights-updated model on Client 41.")
            else:
                print("Client 4 model training failed, skipping further tests.")
        else:
            print("Client 4 could not load data for model management tests.")
    else:
        print("⚠️ Skipping FederatedClient model management test as DATA_DIR or FL_DIR not found (Run Section 1).")
    print("\n--- End of FederatedClient Model Management Test ---")

print("\n✅ Section 4 (FederatedClient Class - Part 3: Model Management & Evaluation) is ready.")


--- Testing FederatedClient Model Management ---

-- Training a model on Client 4 to test save/load/evaluate --


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Client 4: Training completed. Test accuracy: 0.4500
Client 4 trained. Accuracy on its test split: 0.4500
Client 4: Model saved to /content/federated_ids_ai_project/federated_outputs/client_4_model.h5
Client 40: Model loaded from /content/federated_ids_ai_project/federated_outputs/client_4_model.h5
Model loaded into new client instance (client40) successfully.
Evaluation of loaded model on Client 40 (using Client 4's test data): Accuracy 0.4500
Weights updated on client_for_weights_test (client 41): True


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Evaluation of weights-updated model on Client 41: Accuracy 0.4500

--- End of FederatedClient Model Management Test ---

✅ Section 4 (FederatedClient Class - Part 3: Model Management & Evaluation) is ready.


In [6]:
# Imports from Section 1 should still be in effect.
# FederatedClient class from Sections 2, 3, 4 should be defined.
# Ensure all necessary modules like os, np, pd, LabelEncoder, StandardScaler, train_test_split,
# Adam, Sequential, LSTM, Dense, Dropout, EarlyStopping, datetime are available from previous cells or Section 1.

class FederatedClient:
    """
    Represents a client (edge device) in the federated learning system.
    Each client has its own local data and model.
    (Includes methods from Part 1, 2, 3 and this correction)
    """

    def __init__(self, client_id, data=None, model=None):
        self.client_id = client_id
        self.data = data
        self.model = model
        self.history = []
        # print(f"Client {self.client_id}: Initialized.")

    def load_data(self, data_path=None, partition_identifier=None, total_partitions=1):
        if data_path and os.path.exists(data_path):
            try:
                if data_path.endswith('.csv'): full_data_df = pd.read_csv(data_path, low_memory=False)
                elif data_path.endswith(('.json', '.jsonl')): full_data_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
                else: raise ValueError(f"Unsupported file format: {data_path}")

                if partition_identifier is not None and total_partitions > 0 and total_partitions <= len(full_data_df):
                    if not (0 <= partition_identifier < total_partitions): raise ValueError("partition_identifier out of range.")
                    num_samples_total = len(full_data_df)
                    samples_per_partition = num_samples_total // total_partitions
                    start_idx = partition_identifier * samples_per_partition
                    end_idx = (partition_identifier + 1) * samples_per_partition if partition_identifier < total_partitions - 1 else num_samples_total
                    self.data = full_data_df.iloc[start_idx:end_idx].copy()
                else:
                    self.data = full_data_df
                return True
            except Exception as e:
                print(f"Client {self.client_id}: Error loading data from '{data_path}': {e}"); self.data = None; return False
        elif isinstance(self.data, pd.DataFrame): return True
        elif self.data is not None:
             try: self.data = pd.DataFrame(self.data); return True
             except Exception as e: print(f"Client {self.client_id}: Could not convert pre-loaded data: {e}"); self.data=None; return False
        else: self.data = None; return False

    def preprocess_data(self, feature_columns=None, target_column=None, test_size=0.2, reshape_for_lstm=False):
        if self.data is None or self.data.empty: return None
        try:
            df_processed = self.data.copy()
            df_processed.replace([np.inf, -np.inf], np.nan, inplace=True)
            df_processed.dropna(inplace=True)
            if df_processed.empty: return None

            if target_column is None: target_column = df_processed.columns[-1]
            if target_column not in df_processed.columns: return None
            y = df_processed[target_column]

            if feature_columns is None: X = df_processed.drop(columns=[target_column])
            else:
                missing_cols = [col for col in feature_columns if col not in df_processed.columns]
                if missing_cols: print(f"Client {self.client_id}: Missing features: {missing_cols}"); return None
                X = df_processed[feature_columns]

            used_feature_columns = X.columns.tolist()
            X_numeric = X.select_dtypes(include=np.number)
            if X_numeric.shape[1] < X.shape[1]: pass
            X = X_numeric
            if X.empty: return None

            label_encoder = LabelEncoder()
            y_encoded = label_encoder.fit_transform(y)
            num_classes = len(label_encoder.classes_)

            # --- CORRECTED random_state ---
            # Use a hash of the client_id or a fixed integer for reproducibility if client_id is string
            client_seed = hash(str(self.client_id)) % (2**32 -1) # Ensure it's within integer limits for random_state

            min_samples_per_class_for_stratify = 2
            class_counts = pd.Series(y_encoded).value_counts()
            can_stratify = num_classes >= 2 and all(count >= min_samples_per_class_for_stratify for count in class_counts if count >0)

            stratify_option = y_encoded if can_stratify else None
            if len(X) < 2 :
                 print(f"Client {self.client_id}: Not enough samples ({len(X)}) to perform train/test split."); return None

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=client_seed, stratify=stratify_option # Use client_seed
                )
            except ValueError:
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y_encoded, test_size=test_size, random_state=client_seed # Use client_seed
                )

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            if len(X_test) > 0:
                 X_test_scaled = scaler.transform(X_test)
            else:
                 X_test_scaled = np.array([])

            if reshape_for_lstm:
                X_train_scaled = X_train_scaled.reshape(X_train_scaled.shape[0], 1, X_train_scaled.shape[1])
                if len(X_test_scaled) > 0:
                     X_test_scaled = X_test_scaled.reshape(X_test_scaled.shape[0], 1, X_test_scaled.shape[1])

            preproc_objects = {'scaler': scaler, 'label_encoder': label_encoder, 'used_feature_columns': used_feature_columns, 'target_column_name': target_column, 'num_classes': num_classes}
            return X_train_scaled, X_test_scaled, y_train, y_test, preproc_objects
        except Exception as e:
            print(f"Client {self.client_id}: Error preprocessing: {e}"); import traceback; traceback.print_exc(); return None

    def build_lstm_model(self, input_shape, num_classes):
        model = Sequential([
            LSTM(64, input_shape=input_shape, return_sequences=True), Dropout(0.2),
            LSTM(32), Dropout(0.2), Dense(16, activation='relu'),
            Dense(num_classes, activation='softmax')])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model

    def build_mlp_model(self, input_shape, num_classes):
        model = Sequential([
            Dense(128, activation='relu', input_shape=input_shape), Dropout(0.3),
            Dense(64, activation='relu'), Dropout(0.2), Dense(32, activation='relu'),
            Dense(num_classes, activation='softmax')])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model

    def train_local_model(self, model_type='mlp', epochs=10, batch_size=32, verbose=0, feature_columns=None, target_column=None):
        if self.data is None: return None
        preproc_result = self.preprocess_data(feature_columns, target_column, reshape_for_lstm=(model_type.lower() == 'lstm'))
        if preproc_result is None: return None
        X_train, X_test, y_train, y_test, preproc_objects = preproc_result
        num_classes = preproc_objects['num_classes']
        if X_train.shape[0] == 0: return None

        if model_type.lower() == 'lstm':
            model_input_shape = (X_train.shape[1], X_train.shape[2])
            self.model = self.build_lstm_model(model_input_shape, num_classes)
        else:
            model_input_shape = (X_train.shape[1],)
            self.model = self.build_mlp_model(model_input_shape, num_classes)

        validation_data_tuple = (X_test, y_test) if len(X_test) > 0 else None
        early_stopping = EarlyStopping(monitor='val_loss' if validation_data_tuple else 'loss', patience=5, restore_best_weights=True, verbose=0)
        history_obj = self.model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=validation_data_tuple, callbacks=[early_stopping], verbose=verbose)

        test_loss, test_accuracy = (0.0,0.0)
        if validation_data_tuple :
             test_loss, test_accuracy = self.model.evaluate(X_test, y_test, verbose=0)

        train_summary = {'client_id':self.client_id, 'model_type':model_type, 'epochs_run':len(history_obj.history['loss']), 'batch_size':batch_size, 'history':{k:[round(float(v),4) for v in val_list] for k,val_list in history_obj.history.items()}, 'test_loss':float(test_loss), 'test_accuracy':float(test_accuracy), 'timestamp':datetime.now().isoformat()}
        self.history.append(train_summary)
        # print(f"Client {self.client_id}: Training completed. Test accuracy: {test_accuracy:.4f}") # Verbose
        return {'model_weights': self.model.get_weights(), 'num_samples': X_train.shape[0] + (X_test.shape[0] if len(X_test)>0 else 0) , 'preprocessing_details': preproc_objects, 'training_summary': train_summary}

    def update_model(self, model_weights):
        if self.model is None: print(f"Client {self.client_id}: No model to update."); return False
        try: self.model.set_weights(model_weights); return True
        except Exception as e: print(f"Client {self.client_id}: Error updating weights: {e}"); return False

    def get_model_weights(self):
        if self.model is None: print(f"Client {self.client_id}: No model for weights."); return None
        return self.model.get_weights()

    def evaluate_model(self, X_test_external=None, y_test_external=None, preproc_objects_external=None):
        if self.model is None: print(f"Client {self.client_id}: No model to evaluate."); return None
        X_test_eval, y_test_eval = X_test_external, y_test_external
        if X_test_eval is None or y_test_eval is None :
            if not self.history: print(f"Client {self.client_id}: No history & no external data."); return None
            last_sum = self.history[-1]
            # print(f"Client {self.client_id}: Reporting from last local training: Acc {last_sum['test_accuracy']:.4f}") # Verbose
            return {'client_id':self.client_id, 'loss':last_sum['test_loss'], 'accuracy':last_sum['test_accuracy'], 'timestamp':datetime.now().isoformat(), 'note':'Metrics from last local training validation.'}

        if self.model.layers[0].__class__.__name__ == 'LSTM' and len(X_test_eval.shape) == 2:
            X_test_eval = X_test_eval.reshape(X_test_eval.shape[0], 1, X_test_eval.shape[1])
        if isinstance(X_test_eval, np.ndarray) and X_test_eval.size == 0: print(f"Client {self.client_id}: No test data."); return None

        loss, accuracy = self.model.evaluate(X_test_eval, y_test_eval, verbose=0)
        y_pred_probs = self.model.predict(X_test_eval, verbose=0); y_pred_classes = np.argmax(y_pred_probs, axis=1)
        report_dict={}; conf_matrix_list=[]
        try:
            labels = np.union1d(np.unique(y_test_eval), np.unique(y_pred_classes))
            report_dict = classification_report(y_test_eval,y_pred_classes,labels=labels if len(labels)>0 else None,output_dict=True,zero_division=0)
            conf_matrix_list = confusion_matrix(y_test_eval,y_pred_classes,labels=labels if len(labels)>0 else None).tolist()
        except Exception as e: print(f"Client {self.client_id}: Could not gen report/matrix: {e}")
        return {'client_id':self.client_id, 'loss':float(loss), 'accuracy':float(accuracy), 'classification_report':report_dict, 'confusion_matrix':conf_matrix_list, 'timestamp':datetime.now().isoformat()}

    def save_model(self, path=None):
        if self.model is None: print(f"Client {self.client_id}: No model to save."); return None
        if path is None: path = os.path.join(FL_DIR if 'FL_DIR' in globals() else ".", f"client_{self.client_id}_model.h5")
        try: self.model.save(path); print(f"Client {self.client_id}: Model saved to {path}"); return path
        except Exception as e: print(f"Client {self.client_id}: Error saving model: {e}"); return None

    def load_model(self, path):
        try: self.model = tf.keras.models.load_model(path); print(f"Client {self.client_id}: Model loaded from {path}"); return True
        except Exception as e: print(f"Client {self.client_id}: Error loading model from {path}: {e}"); return False

class FederatedServer:
    """
    Represents the central server in the federated learning system.
    Coordinates the training process across multiple clients.
    (Definition for Part 1 of the Server)
    """

    def __init__(self):
        self.clients = {}
        self.global_model = None
        self.global_model_type = None
        self.global_model_input_shape = None
        self.global_model_num_classes = None
        self.aggregation_history = []
        print("Federated Server: Initialized.")

    def add_client(self, client):
        if not isinstance(client, FederatedClient):
            print(f"Error: Attempted to add an object that is not a FederatedClient instance.")
            return
        if client.client_id in self.clients:
            print(f"Server: Client {client.client_id} already registered.")
        else:
            self.clients[client.client_id] = client
            print(f"Server: Added client {client.client_id}.")

    def remove_client(self, client_id):
        if client_id in self.clients:
            del self.clients[client_id]
            print(f"Server: Removed client {client_id}.")
            return True
        else:
            print(f"Server: Client {client_id} not found.")
            return False

    def initialize_global_model(self, model_type='mlp', input_shape=None, num_classes=None, client_for_shape_details=None):
        self.global_model_type = model_type.lower()

        if input_shape is None or num_classes is None:
            if client_for_shape_details and client_for_shape_details.data is not None:
                # print(f"Server: Inferring model shape from client {client_for_shape_details.client_id}...") # Verbose
                # Attempt to get feature names robustly
                temp_feature_columns = None
                if hasattr(client_for_shape_details.data, 'columns'):
                    temp_data_cols = client_for_shape_details.data.columns.tolist()
                    if 'label' in temp_data_cols: # Assuming 'label' is a common target name
                        temp_feature_columns = [col for col in temp_data_cols if col != 'label' and pd.api.types.is_numeric_dtype(client_for_shape_details.data[col])]
                    elif len(temp_data_cols) > 1:
                        temp_feature_columns = [col for col in temp_data_cols[:-1] if pd.api.types.is_numeric_dtype(client_for_shape_details.data[col])]

                preproc_result = client_for_shape_details.preprocess_data(
                    feature_columns=temp_feature_columns,
                    reshape_for_lstm=(self.global_model_type == 'lstm')
                )
                if preproc_result:
                    X_train_sample, _, _, _, preproc_objects = preproc_result
                    if input_shape is None:
                        input_shape = (X_train_sample.shape[1], X_train_sample.shape[2]) if self.global_model_type == 'lstm' else (X_train_sample.shape[1],)
                    if num_classes is None:
                        num_classes = preproc_objects['num_classes']
                    # print(f"Server: Inferred input_shape={input_shape}, num_classes={num_classes}.") # Verbose
                else:
                    # print("Server: Failed to infer from client data. Using defaults.") # Verbose
                    pass # Will fall through to defaults
            # else:
                # print("Server: No client data for shape inference. Using defaults.") # Verbose

        if input_shape is None: input_shape = (1, 10) if self.global_model_type == 'lstm' else (10,)
        if num_classes is None: num_classes = 2

        self.global_model_input_shape = input_shape
        self.global_model_num_classes = num_classes

        temp_client_for_build = FederatedClient(client_id="server_model_builder_temp")
        if self.global_model_type == 'lstm':
            self.global_model = temp_client_for_build.build_lstm_model(input_shape, num_classes)
        else:
            self.global_model = temp_client_for_build.build_mlp_model(input_shape, num_classes)
        del temp_client_for_build

        print(f"Server: Initialized global {self.global_model_type.upper()} model.")
        # self.global_model.summary() # Can be verbose

# --- Test Block for FederatedServer (Part 1: Init & Client Management) ---
# This test runs only if the cell is executed directly in Colab
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedServer (Part 1) ---")

    dummy_csv_path_server_test = None
    if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR):
        dummy_csv_path_server_test = os.path.join(DATA_DIR, "server_client_dummy_data.csv")
        if not os.path.exists(dummy_csv_path_server_test):
            pd.DataFrame({
                'f1_numeric': np.random.rand(50),
                'f2_numeric': np.random.rand(50),
                'target_label': np.random.choice(['ClassA', 'ClassB'], 50) # Changed target column name for clarity
            }).to_csv(dummy_csv_path_server_test, index=False)
            print(f"Created dummy data for server test: {dummy_csv_path_server_test}")
    else:
        print("⚠️ DATA_DIR not found, cannot create dummy data for server test.")

    server = FederatedServer()
    client_A = None # Initialize to None

    if dummy_csv_path_server_test:
        client_A = FederatedClient(client_id="client_A_for_server_test")
        client_A.load_data(data_path=dummy_csv_path_server_test, partition_identifier=0, total_partitions=1) # Load full dummy data

    client_B = FederatedClient(client_id="client_B_for_server_test")

    if client_A: server.add_client(client_A)
    server.add_client(client_B)

    print(f"Server has {len(server.clients)} clients registered.")

    print("\n-- Initializing Global MLP Model (inferring from client_A if possible) --")
    server.initialize_global_model(model_type='mlp', client_for_shape_details=client_A) # Pass client_A object
    if server.global_model:
        print("Global MLP Model Initialized by Server.")
        server.global_model.summary() # Print summary to verify shape

    server.remove_client("client_B_for_server_test")
    print(f"Server has {len(server.clients)} clients after removal.")

    print("\n--- End of FederatedServer (Part 1) Test ---")

print("\n✅ Section 5 (FederatedServer Class - Part 1: Init & Client Management) is ready.")


--- Testing FederatedServer (Part 1) ---
Created dummy data for server test: /content/federated_ids_ai_project/data/server_client_dummy_data.csv
Federated Server: Initialized.
Server: Added client client_A_for_server_test.
Server: Added client client_B_for_server_test.
Server has 2 clients registered.

-- Initializing Global MLP Model (inferring from client_A if possible) --
Server: Initialized global MLP model.
Global MLP Model Initialized by Server.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Server: Removed client client_B_for_server_test.
Server has 1 clients after removal.

--- End of FederatedServer (Part 1) Test ---

✅ Section 5 (FederatedServer Class - Part 1: Init & Client Management) is ready.


In [7]:
# Imports from Section 1 should still be in effect.
# FederatedClient class (all parts) and FederatedServer class (Part 1) should be defined.

class FederatedServer:
    """
    Represents the central server in the federated learning system.
    Coordinates the training process across multiple clients.
    (Includes methods from Part 1 and adds new methods from Part 2)
    """

    def __init__(self):
        self.clients = {}
        self.global_model = None
        self.global_model_type = None
        self.global_model_input_shape = None
        self.global_model_num_classes = None
        self.aggregation_history = []
        # print("Federated Server: Initialized.") # Verbosity controlled

    def add_client(self, client):
        if not isinstance(client, FederatedClient): return
        if client.client_id in self.clients: return
        self.clients[client.client_id] = client
        # print(f"Server: Added client {client.client_id}.")

    def remove_client(self, client_id):
        if client_id in self.clients:
            del self.clients[client_id]; return True
        return False

    def initialize_global_model(self, model_type='mlp', input_shape=None, num_classes=None, client_for_shape_details=None):
        self.global_model_type = model_type.lower()
        final_input_shape = input_shape
        final_num_classes = num_classes

        if final_input_shape is None or final_num_classes is None:
            if client_for_shape_details and client_for_shape_details.data is not None:
                # print(f"Server: Inferring model shape from client {client_for_shape_details.client_id}...")
                temp_feature_columns = None
                if hasattr(client_for_shape_details.data, 'columns'):
                    temp_data_cols = client_for_shape_details.data.columns.tolist()
                    # Attempt to find a 'label' or 'target' column, then numeric features
                    potential_target_col = None
                    if 'label' in temp_data_cols: potential_target_col = 'label'
                    elif 'target' in temp_data_cols: potential_target_col = 'target'
                    # If no specific target, assume last for preprocessing, but features are numeric only.

                    temp_df_numeric_cols = client_for_shape_details.data.select_dtypes(include=np.number).columns.tolist()
                    if potential_target_col and potential_target_col in temp_df_numeric_cols : # if target is numeric and selected
                         temp_feature_columns = [col for col in temp_df_numeric_cols if col != potential_target_col]
                    elif temp_df_numeric_cols: # if target is not numeric or not in numeric, take all numeric as features
                         temp_feature_columns = temp_df_numeric_cols

                preproc_result = client_for_shape_details.preprocess_data(
                    feature_columns=temp_feature_columns,
                    target_column=potential_target_col, # Pass potential target
                    reshape_for_lstm=(self.global_model_type == 'lstm')
                )
                if preproc_result:
                    X_train_sample, _, _, _, preproc_objects = preproc_result
                    if final_input_shape is None:
                        final_input_shape = (X_train_sample.shape[1], X_train_sample.shape[2]) if self.global_model_type == 'lstm' else (X_train_sample.shape[1],)
                    if final_num_classes is None:
                        final_num_classes = preproc_objects['num_classes']
                    # print(f"Server: Inferred input_shape={final_input_shape}, num_classes={final_num_classes}.")

        if final_input_shape is None: final_input_shape = (1, 10) if self.global_model_type == 'lstm' else (10,)
        if final_num_classes is None: final_num_classes = 2

        self.global_model_input_shape = final_input_shape
        self.global_model_num_classes = final_num_classes

        temp_client_for_build = FederatedClient(client_id="server_model_builder_temp")
        if self.global_model_type == 'lstm':
            self.global_model = temp_client_for_build.build_lstm_model(self.global_model_input_shape, self.global_model_num_classes)
        else:
            self.global_model = temp_client_for_build.build_mlp_model(self.global_model_input_shape, self.global_model_num_classes)
        del temp_client_for_build

        print(f"Server: Initialized global {self.global_model_type.upper()} model (Input: {self.global_model_input_shape}, Classes: {self.global_model_num_classes}).")
        # self.global_model.summary()


    # --- New methods for Section 6 ---
    def federated_averaging(self, client_model_weights_list, client_data_sizes=None):
        """
        Perform federated averaging (FedAvg) on a list of client model weights.

        Args:
            client_model_weights_list (list): A list where each element is the weights
                                             (list of numpy arrays) from a client's model.
            client_data_sizes (list, optional): A list of integers representing the number
                                                of data samples each client used for training.
                                                If None, equal weighting is applied.
        Returns:
            list: The new global model weights (averaged).
        """
        if not client_model_weights_list:
            print("Server: No client weights provided for averaging.")
            return None

        # Initialize sum of weights with zeros, matching the structure of the first client's weights
        avg_weights = [np.zeros_like(layer_weights) for layer_weights in client_model_weights_list[0]]

        total_data_size = 0
        if client_data_sizes:
            if len(client_model_weights_list) != len(client_data_sizes):
                print("Server: Mismatch between number of client weights and data sizes. Using equal weighting.")
                client_data_sizes = None # Fallback to equal weighting
            else:
                total_data_size = sum(client_data_sizes)

        if total_data_size == 0 or client_data_sizes is None: # Handles equal weighting or empty data sizes
            print("Server: Using equal weighting for federated averaging.")
            # Equal weighting
            num_clients_with_weights = len(client_model_weights_list)
            for client_weights in client_model_weights_list:
                for i, layer_weights in enumerate(client_weights):
                    avg_weights[i] += layer_weights / num_clients_with_weights
        else:
            # Weighted averaging based on data size
            print(f"Server: Using weighted averaging based on data sizes (total samples: {total_data_size}).")
            for idx, client_weights in enumerate(client_model_weights_list):
                weight_factor = client_data_sizes[idx] / total_data_size
                for i, layer_weights in enumerate(client_weights):
                    avg_weights[i] += layer_weights * weight_factor

        print("Server: Federated averaging completed.")
        return avg_weights

    def train_round(self, round_number, num_selected_clients=None, epochs_per_client=5, batch_size_per_client=32):
        """
        Conducts a single round of federated training.
        1. Selects a subset of available clients.
        2. Sends the current global model weights to these clients.
        3. Clients train the model on their local data.
        4. Server collects updated model weights from clients.
        5. Server aggregates these weights (e.g., FedAvg) to update the global model.
        """
        if self.global_model is None:
            print("Server: Global model not initialized. Cannot start training round.")
            return None
        if not self.clients:
            print("Server: No clients registered. Cannot start training round.")
            return None

        # Select clients for this round
        available_client_ids = list(self.clients.keys())
        if not available_client_ids:
            print("Server: No available clients for this round.")
            return None

        if num_selected_clients is None:
            num_selected_clients = len(available_client_ids) # Use all available clients

        selected_client_ids = random.sample(
            available_client_ids,
            min(num_selected_clients, len(available_client_ids)) # Ensure not to select more than available
        )

        if not selected_client_ids:
            print("Server: No clients were selected for this training round.")
            return None

        print(f"\n--- Server: Starting Federated Training Round {round_number} ---")
        print(f"Server: Selected {len(selected_client_ids)} clients for this round: {selected_client_ids}")

        current_global_weights = self.global_model.get_weights()
        collected_client_weights = []
        client_data_sizes_for_round = []
        successful_clients_this_round = []

        for client_id in selected_client_ids:
            client = self.clients[client_id]
            print(f"Server: Training client {client.client_id}...")

            # Ensure client has a model structure compatible with global model
            # If client has no model, or if model type/shape mis-match, re-initialize from global
            if client.model is None or \
               client.model.layers[0].input_shape[1:] != self.global_model_input_shape or \
               client.model.layers[-1].output_shape[-1] != self.global_model_num_classes:
                print(f"Client {client.client_id}: Model structure mismatch or not initialized. Rebuilding from global specs.")
                temp_builder_client = FederatedClient(client_id="temp_builder") # Use dummy for build methods
                if self.global_model_type == 'lstm':
                    client.model = temp_builder_client.build_lstm_model(self.global_model_input_shape, self.global_model_num_classes)
                else:
                    client.model = temp_builder_client.build_mlp_model(self.global_model_input_shape, self.global_model_num_classes)
                del temp_builder_client

            client.update_model(current_global_weights) # Send global model to client

            # Client trains locally
            # Assuming client data is already loaded. Feature/target columns need to be consistent.
            # For this simulation, let's assume standard feature/target column names if not specified.
            training_result = client.train_local_model(
                model_type=self.global_model_type,
                epochs=epochs_per_client,
                batch_size=batch_size_per_client,
                verbose=0 # Keep client training quiet for server logs
            )

            if training_result and 'model_weights' in training_result:
                collected_client_weights.append(training_result['model_weights'])
                client_data_sizes_for_round.append(training_result['num_samples'])
                successful_clients_this_round.append(client_id)
                print(f"Client {client.client_id}: Training successful. Accuracy: {training_result['training_summary']['test_accuracy']:.4f}")
            else:
                print(f"Client {client.client_id}: Training failed or returned no weights.")

        if not collected_client_weights:
            print("Server: No weights collected from clients in this round. Global model not updated.")
            return {'round': round_number, 'status': 'failed_no_weights', 'successful_clients': []}

        # Aggregate weights (Federated Averaging)
        new_global_weights = self.federated_averaging(collected_client_weights, client_data_sizes_for_round)
        if new_global_weights:
            self.global_model.set_weights(new_global_weights)
            print("Server: Global model updated with aggregated weights.")

        round_summary = {
            'round': round_number,
            'status': 'completed',
            'num_selected_clients': len(selected_client_ids),
            'successful_clients': successful_clients_this_round,
            'client_data_sizes': client_data_sizes_for_round
        }
        self.aggregation_history.append(round_summary)
        return round_summary

# --- Test Block for FederatedServer (Part 2: Training Round) ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedServer (Part 2: Training Round) ---")

    # Ensure DATA_DIR and FL_DIR are available from Section 1
    if 'DATA_DIR' not in globals() or not os.path.exists(DATA_DIR) or \
       'FL_DIR' not in globals() or not os.path.exists(FL_DIR):
        print("⚠️ DATA_DIR or FL_DIR not found. Skipping training round test. Run Section 1 first.")
    else:
        # Re-create server and clients for a clean test
        server_test_part2 = FederatedServer()

        # Create dummy data for clients if it doesn't exist
        dummy_data_paths = []
        num_test_clients = 3
        for i in range(num_test_clients):
            client_data_path = os.path.join(DATA_DIR, f"client_{i}_data_part2.csv")
            dummy_data_paths.append(client_data_path)
            if not os.path.exists(client_data_path):
                pd.DataFrame({
                    'f1': np.random.rand(100 + i*20), # Varying data sizes
                    'f2': np.random.rand(100 + i*20),
                    'label': np.random.choice(['Class0', 'Class1'], 100 + i*20)
                }).to_csv(client_data_path, index=False)
                print(f"Created dummy data for client {i} at {client_data_path}")

        # Add clients
        client_objects_for_test = []
        for i in range(num_test_clients):
            client = FederatedClient(client_id=f"client_train_test_{i}")
            client.load_data(data_path=dummy_data_paths[i]) # Load entire dummy file for each
            server_test_part2.add_client(client)
            client_objects_for_test.append(client)

        if client_objects_for_test and client_objects_for_test[0].data is not None:
            print("\n-- Initializing Global MLP Model for Round Test --")
            server_test_part2.initialize_global_model(
                model_type='mlp',
                client_for_shape_details=client_objects_for_test[0] # Use first client for shape
            )

            if server_test_part2.global_model:
                print("\n-- Running one round of Federated Training --")
                # Use specific feature/target names for dummy data
                round_1_results = server_test_part2.train_round(
                    round_number=1,
                    num_selected_clients=2, # Select 2 out of 3 clients
                    epochs_per_client=2,    # Short epochs
                    batch_size_per_client=16
                )
                if round_1_results:
                    print(f"Round 1 summary: {round_1_results}")
            else:
                print("Global model not initialized, cannot run training round.")
        else:
            print("Failed to load data for client_objects_for_test[0], cannot initialize global model for round test.")

    print("\n--- End of FederatedServer (Part 2) Test ---")

print("\n✅ Section 6 (FederatedServer Class - Part 2: Federated Averaging & Training Round) is ready.")


--- Testing FederatedServer (Part 2: Training Round) ---
Created dummy data for client 0 at /content/federated_ids_ai_project/data/client_0_data_part2.csv
Created dummy data for client 1 at /content/federated_ids_ai_project/data/client_1_data_part2.csv
Created dummy data for client 2 at /content/federated_ids_ai_project/data/client_2_data_part2.csv

-- Initializing Global MLP Model for Round Test --
Server: Initialized global MLP model (Input: (2,), Classes: 2).

-- Running one round of Federated Training --

--- Server: Starting Federated Training Round 1 ---
Server: Selected 2 clients for this round: ['client_train_test_0', 'client_train_test_2']
Server: Training client client_train_test_0...
Client client_train_test_0: Model structure mismatch or not initialized. Rebuilding from global specs.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Client client_train_test_0: Training successful. Accuracy: 0.4500
Server: Training client client_train_test_2...
Client client_train_test_2: Model structure mismatch or not initialized. Rebuilding from global specs.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Client client_train_test_2: Training successful. Accuracy: 0.4286
Server: Using weighted averaging based on data sizes (total samples: 240).
Server: Federated averaging completed.
Server: Global model updated with aggregated weights.
Round 1 summary: {'round': 1, 'status': 'completed', 'num_selected_clients': 2, 'successful_clients': ['client_train_test_0', 'client_train_test_2'], 'client_data_sizes': [100, 140]}

--- End of FederatedServer (Part 2) Test ---

✅ Section 6 (FederatedServer Class - Part 2: Federated Averaging & Training Round) is ready.


In [8]:
# Imports from Section 1 should still be in effect.
# FederatedClient class (all parts) and FederatedServer class (Parts 1 & 2) should be defined.

class FederatedServer:
    """
    Represents the central server in the federated learning system.
    Coordinates the training process across multiple clients.
    (Includes methods from Part 1 & 2, and adds new methods from Part 3)
    """

    def __init__(self):
        self.clients = {}
        self.global_model = None
        self.global_model_type = None
        self.global_model_input_shape = None
        self.global_model_num_classes = None
        self.aggregation_history = []
        # print("Federated Server: Initialized.")

    def add_client(self, client):
        if not isinstance(client, FederatedClient): return
        if client.client_id in self.clients: return
        self.clients[client.client_id] = client
        # print(f"Server: Added client {client.client_id}.")

    def remove_client(self, client_id):
        if client_id in self.clients:
            del self.clients[client_id]; return True
        return False

    def initialize_global_model(self, model_type='mlp', input_shape=None, num_classes=None, client_for_shape_details=None):
        self.global_model_type = model_type.lower()
        final_input_shape = input_shape
        final_num_classes = num_classes

        if final_input_shape is None or final_num_classes is None:
            if client_for_shape_details and client_for_shape_details.data is not None:
                temp_feature_columns = None
                if hasattr(client_for_shape_details.data, 'columns'):
                    temp_data_cols = client_for_shape_details.data.columns.tolist()
                    potential_target_col = None
                    if 'label' in temp_data_cols: potential_target_col = 'label'
                    elif 'target' in temp_data_cols: potential_target_col = 'target'

                    temp_df_numeric_cols = client_for_shape_details.data.select_dtypes(include=np.number).columns.tolist()
                    if potential_target_col and potential_target_col in temp_df_numeric_cols :
                         temp_feature_columns = [col for col in temp_df_numeric_cols if col != potential_target_col]
                    elif temp_df_numeric_cols:
                         temp_feature_columns = temp_df_numeric_cols

                preproc_result = client_for_shape_details.preprocess_data(
                    feature_columns=temp_feature_columns,
                    target_column=potential_target_col,
                    reshape_for_lstm=(self.global_model_type == 'lstm')
                )
                if preproc_result:
                    X_train_sample, _, _, _, preproc_objects = preproc_result
                    if final_input_shape is None:
                        final_input_shape = (X_train_sample.shape[1], X_train_sample.shape[2]) if self.global_model_type == 'lstm' else (X_train_sample.shape[1],)
                    if final_num_classes is None:
                        final_num_classes = preproc_objects['num_classes']

        if final_input_shape is None: final_input_shape = (1, 10) if self.global_model_type == 'lstm' else (10,)
        if final_num_classes is None: final_num_classes = 2

        self.global_model_input_shape = final_input_shape
        self.global_model_num_classes = final_num_classes

        temp_client_for_build = FederatedClient(client_id="server_model_builder_temp")
        if self.global_model_type == 'lstm':
            self.global_model = temp_client_for_build.build_lstm_model(self.global_model_input_shape, self.global_model_num_classes)
        else:
            self.global_model = temp_client_for_build.build_mlp_model(self.global_model_input_shape, self.global_model_num_classes)
        del temp_client_for_build

        # print(f"Server: Initialized global {self.global_model_type.upper()} model (Input: {self.global_model_input_shape}, Classes: {self.global_model_num_classes}).")


    def federated_averaging(self, client_model_weights_list, client_data_sizes=None):
        if not client_model_weights_list: return None
        avg_weights = [np.zeros_like(w) for w in client_model_weights_list[0]]
        total_data_size = 0
        if client_data_sizes:
            if len(client_model_weights_list) != len(client_data_sizes): client_data_sizes = None
            else: total_data_size = sum(client_data_sizes)

        if total_data_size == 0 or client_data_sizes is None:
            num_clients_with_weights = len(client_model_weights_list)
            for client_weights in client_model_weights_list:
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights / num_clients_with_weights
        else:
            for idx, client_weights in enumerate(client_model_weights_list):
                weight_factor = client_data_sizes[idx] / total_data_size
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights * weight_factor
        return avg_weights

    def train_round(self, round_number, num_selected_clients=None, epochs_per_client=5, batch_size_per_client=32):
        if self.global_model is None or not self.clients: return None
        available_client_ids = list(self.clients.keys())
        if not available_client_ids: return None
        if num_selected_clients is None: num_selected_clients = len(available_client_ids)
        selected_client_ids = random.sample(available_client_ids, min(num_selected_clients, len(available_client_ids)))
        if not selected_client_ids: return None

        print(f"\n--- Server: Starting FL Round {round_number} with {len(selected_client_ids)} clients: {selected_client_ids} ---")
        current_global_weights = self.global_model.get_weights()
        collected_client_weights, client_data_sizes_for_round, successful_clients_this_round = [], [], []

        for client_id in selected_client_ids:
            client = self.clients[client_id]
            # print(f"Server: Training client {client.client_id}...") # Verbose
            if client.model is None or \
               client.model.layers[0].input_shape[1:] != self.global_model_input_shape or \
               client.model.layers[-1].output_shape[-1] != self.global_model_num_classes:
                temp_builder = FederatedClient(client_id="temp")
                client.model = temp_builder.build_lstm_model(self.global_model_input_shape, self.global_model_num_classes) if self.global_model_type == 'lstm' \
                    else temp_builder.build_mlp_model(self.global_model_input_shape, self.global_model_num_classes)
                del temp_builder
            client.update_model(current_global_weights)

            training_result = client.train_local_model(model_type=self.global_model_type, epochs=epochs_per_client, batch_size=batch_size_per_client, verbose=0)
            if training_result and 'model_weights' in training_result:
                collected_client_weights.append(training_result['model_weights'])
                client_data_sizes_for_round.append(training_result['num_samples'])
                successful_clients_this_round.append(client_id)
                # print(f"Client {client.client_id}: Train OK. Acc: {training_result['training_summary']['test_accuracy']:.4f}") # Verbose
            # else: print(f"Client {client.client_id}: Training failed.") # Verbose

        if not collected_client_weights: print("Server: No weights from clients this round."); return {'round':round_number, 'status':'failed_no_weights', 'successful_clients':[]}

        new_global_weights = self.federated_averaging(collected_client_weights, client_data_sizes_for_round)
        if new_global_weights: self.global_model.set_weights(new_global_weights); print("Server: Global model updated via FedAvg.")

        round_summary = {'round':round_number, 'status':'completed', 'num_selected':len(selected_client_ids), 'successful_clients':successful_clients_this_round, 'client_data_sizes':client_data_sizes_for_round}
        self.aggregation_history.append(round_summary)
        return round_summary

    # --- New methods for Section 7 ---
    def evaluate_global_model(self, X_test_global, y_test_global):
        """
        Evaluate the current global model on a provided global test dataset.
        """
        if self.global_model is None:
            print("Server: Global model not initialized. Cannot evaluate.")
            return None

        if X_test_global is None or y_test_global is None or (isinstance(X_test_global, np.ndarray) and X_test_global.size == 0):
            print("Server: No global test data provided for evaluation.")
            return None

        # Reshape X_test_global if the global model is LSTM and data is not already shaped
        X_test_to_eval = X_test_global
        if self.global_model_type == 'lstm' and len(X_test_to_eval.shape) == 2:
            X_test_to_eval = X_test_to_eval.reshape(X_test_to_eval.shape[0], 1, X_test_to_eval.shape[1])

        print(f"Server: Evaluating global model on provided test set (shape: {X_test_to_eval.shape})...")
        loss, accuracy = self.global_model.evaluate(X_test_to_eval, y_test_global, verbose=0)

        y_pred_probs = self.global_model.predict(X_test_to_eval, verbose=0)
        y_pred_classes = np.argmax(y_pred_probs, axis=1)

        report_dict = {}; conf_matrix_list = []
        try:
            labels = np.union1d(np.unique(y_test_global), np.unique(y_pred_classes))
            report_dict = classification_report(y_test_global, y_pred_classes, labels=labels if len(labels)>0 else None, output_dict=True, zero_division=0)
            conf_matrix_list = confusion_matrix(y_test_global, y_pred_classes, labels=labels if len(labels)>0 else None).tolist()
        except Exception as e:
            print(f"Server: Could not generate full report/matrix for global model: {e}")

        eval_result = {
            'timestamp': datetime.now().isoformat(),
            'loss': float(loss), 'accuracy': float(accuracy),
            'classification_report': report_dict,
            'confusion_matrix': conf_matrix_list
        }
        print(f"Server: Global Model Evaluation - Accuracy: {accuracy:.4f}, Loss: {loss:.4f}")
        return eval_result

    def save_global_model(self, path=None):
        """Save the global model to a file."""
        if self.global_model is None: print("Server: No global model to save."); return None
        if path is None:
            # MODEL_DIR should be defined from Section 1
            path = os.path.join(MODEL_DIR if 'MODEL_DIR' in globals() else ".",
                                f"fl_global_{self.global_model_type}_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.h5")
        try:
            self.global_model.save(path)
            print(f"Server: Global model saved to {path}")
            return path
        except Exception as e: print(f"Server: Error saving global model: {e}"); return None

    def load_global_model(self, path):
        """Load a global model from a file."""
        try:
            self.global_model = tf.keras.models.load_model(path) # Use tf.keras.models.load_model
            # Infer type, shape, classes if possible (might need to be stored alongside model)
            self.global_model_type = 'lstm' if any(isinstance(layer, LSTM) for layer in self.global_model.layers) else 'mlp'
            self.global_model_input_shape = self.global_model.layers[0].input_shape[1:] # Exclude batch size
            self.global_model_num_classes = self.global_model.layers[-1].output_shape[-1]
            print(f"Server: Global model loaded from {path} (Type: {self.global_model_type}, Input: {self.global_model_input_shape}, Classes: {self.global_model_num_classes})")
            return True
        except Exception as e: print(f"Server: Error loading global model from {path}: {e}"); return False

    def save_history(self, path=None):
        """Save the aggregation history to a file."""
        if not self.aggregation_history: print("Server: No aggregation history to save."); return None
        if path is None:
            # FL_DIR should be defined from Section 1
            path = os.path.join(FL_DIR if 'FL_DIR' in globals() else ".",
                                f"fl_aggregation_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        try:
            with open(path, 'w') as f: json.dump(self.aggregation_history, f, indent=2)
            print(f"Server: Aggregation history saved to {path}")
            return path
        except Exception as e: print(f"Server: Error saving aggregation history: {e}"); return None

# --- Test Block for FederatedServer (Part 3: Evaluation & Saving) ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedServer (Part 3: Evaluation & Saving) ---")

    if 'DATA_DIR' not in globals() or not os.path.exists(DATA_DIR) or \
       'MODEL_DIR' not in globals() or not os.path.exists(MODEL_DIR) or \
       'FL_DIR' not in globals() or not os.path.exists(FL_DIR):
        print("⚠️ DATA_DIR, MODEL_DIR or FL_DIR not found. Skipping test. Run Section 1 first.")
    elif 'FederatedClient' not in globals() or 'FederatedServer' not in globals():
        print("⚠️ FederatedClient or FederatedServer class not defined. Run previous sections.")
    else:
        # Use existing server_test_part2 if available from Section 6 test, or create new
        if 'server_test_part2' in locals() and isinstance(server_test_part2, FederatedServer) and server_test_part2.global_model:
            server_eval_test = server_test_part2
            print("Using server instance from Part 2 test.")
        else:
            print("Creating new server instance for Part 3 test.")
            server_eval_test = FederatedServer()
            # Setup a client and initialize global model if server_test_part2 was not run/successful
            client_for_init = FederatedClient(client_id="init_client_for_eval_test")
            dummy_data_path_eval = os.path.join(DATA_DIR, "server_client_dummy_data.csv") # From previous test
            if os.path.exists(dummy_data_path_eval):
                client_for_init.load_data(data_path=dummy_data_path_eval)
                server_eval_test.initialize_global_model(model_type='mlp', client_for_shape_details=client_for_init)
            else: # Fallback if dummy data is missing
                server_eval_test.initialize_global_model(model_type='mlp', input_shape=(10,), num_classes=2)


        if server_eval_test.global_model:
            print("\n-- Testing Global Model Evaluation --")
            # Create some dummy global test data
            # Shape should match server_eval_test.global_model_input_shape
            eval_input_shape = server_eval_test.global_model_input_shape
            num_eval_samples = 50

            if server_eval_test.global_model_type == 'lstm':
                # eval_input_shape is (timesteps, features) e.g. (1, 10)
                X_global_test = np.random.rand(num_eval_samples, eval_input_shape[0], eval_input_shape[1])
            else: # MLP
                # eval_input_shape is (features,) e.g. (10,)
                X_global_test = np.random.rand(num_eval_samples, eval_input_shape[0])

            y_global_test = np.random.randint(0, server_eval_test.global_model_num_classes, num_eval_samples)

            eval_results = server_eval_test.evaluate_global_model(X_global_test, y_global_test)
            if eval_results:
                print(f"Global model evaluation result (on random data): Accuracy {eval_results['accuracy']:.4f}")

            print("\n-- Testing Save/Load Global Model --")
            saved_model_path = server_eval_test.save_global_model()
            if saved_model_path and os.path.exists(saved_model_path):
                server_new_load_test = FederatedServer()
                loaded = server_new_load_test.load_global_model(saved_model_path)
                if loaded and server_new_load_test.global_model:
                    print("Global model loaded into new server instance successfully.")
                    # Quick eval to check if loaded model works
                    re_eval_results = server_new_load_test.evaluate_global_model(X_global_test, y_global_test)
                    if re_eval_results:
                         print(f"Re-evaluation of loaded global model: Accuracy {re_eval_results['accuracy']:.4f}")

            print("\n-- Testing Save History --")
            # Add a dummy aggregation if history is empty for testing save
            if not server_eval_test.aggregation_history:
                server_eval_test.aggregation_history.append({'round':0, 'status':'dummy_for_save_test'})
            server_eval_test.save_history()
        else:
            print("Global model not initialized on server_eval_test. Skipping Part 3 tests.")

    print("\n--- End of FederatedServer (Part 3) Test ---")

print("\n✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.")


--- Testing FederatedServer (Part 3: Evaluation & Saving) ---
Creating new server instance for Part 3 test.

-- Testing Global Model Evaluation --
Server: Evaluating global model on provided test set (shape: (50, 2))...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Server: Global Model Evaluation - Accuracy: 0.4600, Loss: 0.7000
Global model evaluation result (on random data): Accuracy 0.4600

-- Testing Save/Load Global Model --
Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234344.h5
Server: Error loading global model from /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234344.h5: 'Dense' object has no attribute 'input_shape'

-- Testing Save History --
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_aggregation_history_20250528_234344.json

--- End of FederatedServer (Part 3) Test ---

✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.


In [9]:
# Imports from Section 1 should still be in effect.
# FederatedClient class (all parts) and FederatedServer class (Parts 1 & 2) should be defined.

class FederatedServer:
    """
    Represents the central server in the federated learning system.
    Coordinates the training process across multiple clients.
    (Includes methods from Part 1 & 2, and adds new methods from Part 3)
    """

    def __init__(self):
        self.clients = {}
        self.global_model = None
        self.global_model_type = None
        # For MLP, input_shape will be like (num_features,). For LSTM, (timesteps, num_features) e.g. (1, num_features)
        self.global_model_input_shape_for_build = None
        self.global_model_num_classes = None
        self.aggregation_history = []
        # print("Federated Server: Initialized.")

    def add_client(self, client):
        if not isinstance(client, FederatedClient): return
        if client.client_id in self.clients: return
        self.clients[client.client_id] = client

    def remove_client(self, client_id):
        if client_id in self.clients:
            del self.clients[client_id]; return True
        return False

    def initialize_global_model(self, model_type='mlp', input_shape=None, num_classes=None, client_for_shape_details=None):
        self.global_model_type = model_type.lower()
        final_input_shape_for_build = input_shape # This is the shape passed to model constructor
        final_num_classes = num_classes

        if final_input_shape_for_build is None or final_num_classes is None:
            if client_for_shape_details and client_for_shape_details.data is not None:
                temp_feature_columns = None
                if hasattr(client_for_shape_details.data, 'columns'):
                    temp_data_cols = client_for_shape_details.data.columns.tolist()
                    potential_target_col = 'label' if 'label' in temp_data_cols else ('target' if 'target' in temp_data_cols else None)
                    temp_df_numeric_cols = client_for_shape_details.data.select_dtypes(include=np.number).columns.tolist()
                    if potential_target_col and potential_target_col in temp_df_numeric_cols :
                         temp_feature_columns = [col for col in temp_df_numeric_cols if col != potential_target_col]
                    elif temp_df_numeric_cols:
                         temp_feature_columns = temp_df_numeric_cols

                preproc_result = client_for_shape_details.preprocess_data(
                    feature_columns=temp_feature_columns,
                    target_column=potential_target_col,
                    reshape_for_lstm=False # Get basic X_train_sample shape first
                )
                if preproc_result:
                    X_train_sample, _, _, _, preproc_objects = preproc_result
                    if final_input_shape_for_build is None:
                        # For MLP: (num_features,)
                        # For LSTM: (timesteps, num_features) e.g., (1, num_features)
                        num_features = X_train_sample.shape[1]
                        final_input_shape_for_build = (1, num_features) if self.global_model_type == 'lstm' else (num_features,)
                    if final_num_classes is None:
                        final_num_classes = preproc_objects['num_classes']

        if final_input_shape_for_build is None: final_input_shape_for_build = (1, 10) if self.global_model_type == 'lstm' else (10,)
        if final_num_classes is None: final_num_classes = 2

        self.global_model_input_shape_for_build = final_input_shape_for_build
        self.global_model_num_classes = final_num_classes

        temp_client_for_build = FederatedClient(client_id="server_model_builder_temp")
        if self.global_model_type == 'lstm':
            self.global_model = temp_client_for_build.build_lstm_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
        else: # MLP
            self.global_model = temp_client_for_build.build_mlp_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
        del temp_client_for_build

        print(f"Server: Initialized global {self.global_model_type.upper()} model (Build Input Shape: {self.global_model_input_shape_for_build}, Classes: {self.global_model_num_classes}).")


    def federated_averaging(self, client_model_weights_list, client_data_sizes=None):
        if not client_model_weights_list: return None
        avg_weights = [np.zeros_like(w) for w in client_model_weights_list[0]]
        total_data_size = 0
        if client_data_sizes:
            if len(client_model_weights_list) != len(client_data_sizes): client_data_sizes = None
            else: total_data_size = sum(client_data_sizes)

        if total_data_size == 0 or client_data_sizes is None:
            num_clients_with_weights = len(client_model_weights_list)
            for client_weights in client_model_weights_list:
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights / num_clients_with_weights
        else:
            for idx, client_weights in enumerate(client_model_weights_list):
                weight_factor = client_data_sizes[idx] / total_data_size
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights * weight_factor
        return avg_weights

    def train_round(self, round_number, num_selected_clients=None, epochs_per_client=5, batch_size_per_client=32):
        if self.global_model is None or not self.clients: return None
        available_client_ids = list(self.clients.keys())
        if not available_client_ids: return None
        if num_selected_clients is None: num_selected_clients = len(available_client_ids)
        selected_client_ids = random.sample(available_client_ids, min(num_selected_clients, len(available_client_ids)))
        if not selected_client_ids: return None

        print(f"\n--- Server: Starting FL Round {round_number} with {len(selected_client_ids)} clients: {selected_client_ids} ---")
        current_global_weights = self.global_model.get_weights()
        collected_client_weights, client_data_sizes_for_round, successful_clients_this_round = [], [], []

        for client_id in selected_client_ids:
            client = self.clients[client_id]

            # Ensure client has a model structure compatible with global model
            # Global model input shape for build must be used here
            if client.model is None or \
               client.model.layers[0].input_shape[1:] != self.global_model_input_shape_for_build or \
               client.model.layers[-1].output_shape[-1] != self.global_model_num_classes:
                temp_builder = FederatedClient(client_id="temp_model_builder_for_client")
                if self.global_model_type == 'lstm':
                    client.model = temp_builder.build_lstm_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
                else:
                    client.model = temp_builder.build_mlp_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
                del temp_builder

            client.update_model(current_global_weights)

            training_result = client.train_local_model(model_type=self.global_model_type, epochs=epochs_per_client, batch_size=batch_size_per_client, verbose=0)
            if training_result and 'model_weights' in training_result:
                collected_client_weights.append(training_result['model_weights'])
                client_data_sizes_for_round.append(training_result['num_samples'])
                successful_clients_this_round.append(client_id)
                print(f"Client {client.client_id}: Training successful. Test Acc: {training_result['training_summary']['test_accuracy']:.4f}")

        if not collected_client_weights: print("Server: No weights collected."); return {'round':round_number, 'status':'failed_no_weights', 'successful_clients':[]}

        new_global_weights = self.federated_averaging(collected_client_weights, client_data_sizes_for_round)
        if new_global_weights: self.global_model.set_weights(new_global_weights); print("Server: Global model updated via FedAvg.")

        round_summary = {'round':round_number, 'status':'completed', 'num_selected':len(selected_client_ids), 'successful_clients':successful_clients_this_round, 'client_data_sizes':client_data_sizes_for_round}
        self.aggregation_history.append(round_summary)
        return round_summary

    def evaluate_global_model(self, X_test_global, y_test_global):
        if self.global_model is None: print("Server: Global model not init."); return None
        if X_test_global is None or y_test_global is None or (isinstance(X_test_global, np.ndarray) and X_test_global.size == 0):
            print("Server: No global test data provided."); return None

        X_test_to_eval = X_test_global
        is_lstm_model = any(isinstance(layer, LSTM) for layer in self.global_model.layers)

        if is_lstm_model and len(X_test_to_eval.shape) == 2:
             # Expected LSTM input e.g. (1, num_features) from self.global_model_input_shape_for_build
             # X_test_to_eval is likely (num_samples, num_features)
             if self.global_model_input_shape_for_build and len(self.global_model_input_shape_for_build) == 2:
                timesteps = self.global_model_input_shape_for_build[0] # Should be 1
                num_features = self.global_model_input_shape_for_build[1]
                if X_test_to_eval.shape[1] == num_features: # Check if feature count matches
                    X_test_to_eval = X_test_to_eval.reshape(X_test_to_eval.shape[0], timesteps, num_features)
                else:
                    print(f"Server: Feature mismatch for LSTM reshaping. Expected {num_features}, got {X_test_to_eval.shape[1]}.")
                    return None # Cannot reshape correctly

        # print(f"Server: Evaluating global model on test set (shape: {X_test_to_eval.shape})...")
        loss, accuracy = self.global_model.evaluate(X_test_to_eval, y_test_global, verbose=0)
        y_pred_probs = self.global_model.predict(X_test_to_eval, verbose=0); y_pred_classes = np.argmax(y_pred_probs, axis=1)
        report_dict = {}; conf_matrix_list = []
        try:
            labels = np.union1d(np.unique(y_test_global), np.unique(y_pred_classes))
            report_dict = classification_report(y_test_global,y_pred_classes,labels=labels if len(labels)>0 else None,output_dict=True,zero_division=0)
            conf_matrix_list = confusion_matrix(y_test_global,y_pred_classes,labels=labels if len(labels)>0 else None).tolist()
        except Exception as e: print(f"Server: Could not gen report/matrix: {e}")
        eval_result = {'timestamp': datetime.now().isoformat(), 'loss': float(loss), 'accuracy': float(accuracy), 'classification_report': report_dict, 'confusion_matrix': conf_matrix_list}
        print(f"Server: Global Model Eval - Accuracy: {accuracy:.4f}, Loss: {loss:.4f}")
        return eval_result

    def save_global_model(self, path=None):
        if self.global_model is None: print("Server: No global model."); return None
        if path is None:
            path = os.path.join(MODEL_DIR if 'MODEL_DIR' in globals() else ".",
                                f"fl_global_{self.global_model_type}_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.h5")
        try: self.global_model.save(path); print(f"Server: Global model saved to {path}"); return path
        except Exception as e: print(f"Server: Error saving global model: {e}"); return None

    def load_global_model(self, path):
        try:
            self.global_model = tf.keras.models.load_model(path)
            print(f"Server: Global model loaded from {path}")

            # Infer model type based on layers
            self.global_model_type = 'lstm' if any(isinstance(layer, LSTM) for layer in self.global_model.layers) else 'mlp'

            # Infer input shape and num_classes from the loaded model's config
            # This is generally more reliable for models saved in .h5 format
            config = self.global_model.get_config()

            # For input_shape (shape passed to constructor, e.g. (1, num_features) for LSTM or (num_features,) for MLP)
            if 'layers' in config and config['layers']:
                first_layer_config = config['layers'][0]['config']
                if 'batch_input_shape' in first_layer_config:
                    # batch_input_shape is (None, timesteps, features) or (None, features)
                    self.global_model_input_shape_for_build = first_layer_config['batch_input_shape'][1:]
                else: # Fallback if batch_input_shape is not there (e.g. functional API model with Input layer)
                    try: # For functional API Input Layer
                         self.global_model_input_shape_for_build = self.global_model.input_shape[1:]
                    except AttributeError:
                         print("Server Warning: Could not reliably infer input_shape from loaded model config. May need manual setting.")
                         self.global_model_input_shape_for_build = (1,10) if self.global_model_type == 'lstm' else (10,)


                # For num_classes from the last layer
                last_layer_config = config['layers'][-1]['config']
                if 'units' in last_layer_config:
                    self.global_model_num_classes = last_layer_config['units']
                else:
                    print("Server Warning: Could not infer num_classes from loaded model config. Defaulting to 2.")
                    self.global_model_num_classes = 2 # Default
            else:
                 print("Server Warning: Loaded model config has no layers. Using defaults for shape/classes.")
                 self.global_model_input_shape_for_build = (1,10) if self.global_model_type == 'lstm' else (10,)
                 self.global_model_num_classes = 2


            # Ensure model is compiled if not already
            if not self.global_model.optimizer:
                print("Server: Loaded model optimizer not found. Re-compiling with default Adam.")
                self.global_model.compile(optimizer=Adam(learning_rate=0.001),
                                          loss='sparse_categorical_crossentropy', metrics=['accuracy'])

            print(f"Server: Inferred from loaded model - Type: {self.global_model_type}, "
                  f"Build Input Shape: {self.global_model_input_shape_for_build}, Classes: {self.global_model_num_classes}")
            return True
        except Exception as e:
            print(f"Server: Error loading global model from {path}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def save_history(self, path=None):
        if not self.aggregation_history: print("Server: No aggregation history."); return None
        if path is None:
            path = os.path.join(FL_DIR if 'FL_DIR' in globals() else ".",
                                f"fl_agg_hist_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        try:
            with open(path, 'w') as f: json.dump(self.aggregation_history, f, indent=2)
            print(f"Server: Aggregation history saved to {path}")
            return path
        except Exception as e: print(f"Server: Error saving agg history: {e}"); return None

# --- Test Block for FederatedServer (Part 3: Evaluation & Saving) ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedServer (Part 3: Evaluation & Saving) ---")

    if 'DATA_DIR' not in globals() or not os.path.exists(DATA_DIR) or \
       'MODEL_DIR' not in globals() or not os.path.exists(MODEL_DIR) or \
       'FL_DIR' not in globals() or not os.path.exists(FL_DIR):
        print("⚠️ DATA_DIR, MODEL_DIR or FL_DIR not found. Skipping test. Run Section 1 first.")
    elif 'FederatedClient' not in globals() or 'FederatedServer' not in globals():
        print("⚠️ FederatedClient or FederatedServer class not defined. Run previous sections.")
    else:
        server_eval_test = None
        if 'server_test_part2' in locals() and isinstance(server_test_part2, FederatedServer) and server_test_part2.global_model:
            server_eval_test = server_test_part2
            print("Using server instance from Part 2 test for Part 3 test.")
        else:
            print("Creating new server instance for Part 3 test (Part 2 server not found or no model).")
            server_eval_test = FederatedServer()
            client_for_init_eval = FederatedClient(client_id="init_client_for_eval_test_p3")
            dummy_data_path_eval = os.path.join(DATA_DIR, "server_client_dummy_data.csv")
            if os.path.exists(dummy_data_path_eval):
                client_for_init_eval.load_data(data_path=dummy_data_path_eval)
                if client_for_init_eval.data is not None:
                     server_eval_test.initialize_global_model(model_type='mlp', client_for_shape_details=client_for_init_eval)
                else: # Fallback if client data load failed
                     server_eval_test.initialize_global_model(model_type='mlp', input_shape=(2,), num_classes=2)
            else: # Fallback if dummy data CSV is missing
                print(f"Dummy data {dummy_data_path_eval} not found, initializing server model with defaults.")
                server_eval_test.initialize_global_model(model_type='mlp', input_shape=(2,), num_classes=2)

        if server_eval_test.global_model:
            print("\n-- Testing Global Model Evaluation --")
            eval_input_shape_for_build = server_eval_test.global_model_input_shape_for_build
            num_eval_samples = 50

            if server_eval_test.global_model_type == 'lstm':
                # eval_input_shape_for_build is (timesteps, features) e.g. (1, num_features)
                X_global_test = np.random.rand(num_eval_samples, eval_input_shape_for_build[0], eval_input_shape_for_build[1])
            else: # MLP
                # eval_input_shape_for_build is (features,) e.g. (num_features,)
                X_global_test = np.random.rand(num_eval_samples, eval_input_shape_for_build[0])

            y_global_test = np.random.randint(0, server_eval_test.global_model_num_classes, num_eval_samples)

            eval_results = server_eval_test.evaluate_global_model(X_global_test, y_global_test)
            if eval_results:
                print(f"Global model evaluation (on random data): Accuracy {eval_results['accuracy']:.4f}")

            print("\n-- Testing Save/Load Global Model --")
            saved_model_path = server_eval_test.save_global_model()
            if saved_model_path and os.path.exists(saved_model_path):
                server_new_load_test = FederatedServer()
                loaded = server_new_load_test.load_global_model(saved_model_path)
                if loaded and server_new_load_test.global_model:
                    print("Global model loaded into new server instance successfully.")
                    # Re-evaluate to check if loaded model works
                    re_eval_results = server_new_load_test.evaluate_global_model(X_global_test, y_global_test)
                    if re_eval_results:
                         print(f"Re-evaluation of loaded global model: Accuracy {re_eval_results['accuracy']:.4f}")

            print("\n-- Testing Save History --")
            if not server_eval_test.aggregation_history:
                server_eval_test.aggregation_history.append({'round':0, 'status':'dummy_for_save_test', 'successful_clients':[], 'client_data_sizes':[]})
            server_eval_test.save_history()
        else:
            print("Global model not initialized on server_eval_test. Skipping Part 3 tests.")

    print("\n--- End of FederatedServer (Part 3) Test ---")

print("\n✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.")


--- Testing FederatedServer (Part 3: Evaluation & Saving) ---
Creating new server instance for Part 3 test (Part 2 server not found or no model).
Server: Initialized global MLP model (Build Input Shape: (2,), Classes: 2).

-- Testing Global Model Evaluation --


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Server: Global Model Eval - Accuracy: 0.5000, Loss: 0.6926
Global model evaluation (on random data): Accuracy 0.5000

-- Testing Save/Load Global Model --
Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234345.h5
Server: Global model loaded from /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234345.h5
Server: Inferred from loaded model - Type: mlp, Build Input Shape: (2,), Classes: 2
Global model loaded into new server instance successfully.
Server: Global Model Eval - Accuracy: 0.5000, Loss: 0.6926
Re-evaluation of loaded global model: Accuracy 0.5000

-- Testing Save History --
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_agg_hist_20250528_234345.json

--- End of FederatedServer (Part 3) Test ---

✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.


In [10]:
# Imports from Section 1 should still be in effect.
# FederatedClient class (all parts) and FederatedServer class (Parts 1 & 2) should be defined.

class FederatedServer:
    """
    Represents the central server in the federated learning system.
    Coordinates the training process across multiple clients.
    (Includes methods from Part 1 & 2, and adds new methods from Part 3)
    """

    def __init__(self):
        self.clients = {}
        self.global_model = None
        self.global_model_type = None
        self.global_model_input_shape_for_build = None
        self.global_model_num_classes = None
        self.aggregation_history = []
        # print("Federated Server: Initialized.")

    def add_client(self, client):
        if not isinstance(client, FederatedClient): return
        if client.client_id in self.clients: return
        self.clients[client.client_id] = client

    def remove_client(self, client_id):
        if client_id in self.clients:
            del self.clients[client_id]; return True
        return False

    def initialize_global_model(self, model_type='mlp',
                                intended_feature_columns=None,
                                intended_target_column=None,
                                client_for_shape_details=None,
                                default_input_shape=None,
                                default_num_classes=2):
        self.global_model_type = model_type.lower()
        final_input_shape_for_build = default_input_shape
        final_num_classes = default_num_classes

        if client_for_shape_details and client_for_shape_details.data is not None and \
           intended_feature_columns and intended_target_column:
            preproc_result = client_for_shape_details.preprocess_data(
                feature_columns=intended_feature_columns,
                target_column=intended_target_column,
                reshape_for_lstm=False
            )
            if preproc_result:
                X_train_sample, _, _, _, preproc_objects = preproc_result
                if X_train_sample is not None and X_train_sample.shape[0] > 0:
                    num_features = X_train_sample.shape[1]
                    final_input_shape_for_build = (1, num_features) if self.global_model_type == 'lstm' else (num_features,)
                    final_num_classes = preproc_objects['num_classes']

        if final_input_shape_for_build is None:
            final_input_shape_for_build = (1, 10) if self.global_model_type == 'lstm' else (10,)
        if final_num_classes is None:
            final_num_classes = 2

        self.global_model_input_shape_for_build = final_input_shape_for_build
        self.global_model_num_classes = final_num_classes

        temp_client_for_build = FederatedClient(client_id="server_model_builder_temp")
        if self.global_model_type == 'lstm':
            self.global_model = temp_client_for_build.build_lstm_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
        else: # MLP
            self.global_model = temp_client_for_build.build_mlp_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
        del temp_client_for_build

        print(f"Server: Initialized global {self.global_model_type.upper()} model (Build Input Shape: {self.global_model_input_shape_for_build}, Classes: {self.global_model_num_classes}).")

    def federated_averaging(self, client_model_weights_list, client_data_sizes=None):
        if not client_model_weights_list: return None
        avg_weights = [np.zeros_like(w) for w in client_model_weights_list[0]]
        total_data_size = 0
        if client_data_sizes:
            if len(client_model_weights_list) != len(client_data_sizes): client_data_sizes = None
            else: total_data_size = sum(client_data_sizes)

        if total_data_size == 0 or client_data_sizes is None:
            num_clients_with_weights = len(client_model_weights_list)
            for client_weights in client_model_weights_list:
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights / num_clients_with_weights
        else:
            for idx, client_weights in enumerate(client_model_weights_list):
                weight_factor = client_data_sizes[idx] / total_data_size
                for i, layer_weights in enumerate(client_weights): avg_weights[i] += layer_weights * weight_factor
        return avg_weights

    def train_round(self, round_number, num_selected_clients=None, epochs_per_client=5, batch_size_per_client=32, feature_columns_for_clients=None, target_column_for_clients=None):
        if self.global_model is None or not self.clients: return None
        available_client_ids = list(self.clients.keys())
        if not available_client_ids: return None
        if num_selected_clients is None: num_selected_clients = len(available_client_ids)
        selected_client_ids = random.sample(available_client_ids, min(num_selected_clients, len(available_client_ids)))
        if not selected_client_ids: return None

        print(f"\n--- Server: Starting FL Round {round_number} with {len(selected_client_ids)} clients: {selected_client_ids} ---")
        current_global_weights = self.global_model.get_weights()
        collected_client_weights, client_data_sizes_for_round, successful_clients_this_round = [], [], []

        for client_id in selected_client_ids:
            client = self.clients[client_id]
            rebuild_client_model = False
            if client.model is None: rebuild_client_model = True
            else:
                try:
                    client_config = client.model.get_config()
                    client_input_shape_from_config = client_config['layers'][0]['config'].get('batch_input_shape')[1:]
                    client_output_units_from_config = client_config['layers'][-1]['config'].get('units')
                    if client_input_shape_from_config != self.global_model_input_shape_for_build or \
                       client_output_units_from_config != self.global_model_num_classes:
                        rebuild_client_model = True
                except Exception: rebuild_client_model = True # Rebuild if error checking structure

            if rebuild_client_model:
                # print(f"Client {client.client_id}: Rebuilding model structure.") # Verbose
                temp_builder = FederatedClient(client_id="temp_builder_round")
                client.model = temp_builder.build_lstm_model(self.global_model_input_shape_for_build, self.global_model_num_classes) if self.global_model_type == 'lstm' \
                    else temp_builder.build_mlp_model(self.global_model_input_shape_for_build, self.global_model_num_classes)
                del temp_builder

            client.update_model(current_global_weights)
            training_result = client.train_local_model(
                model_type=self.global_model_type, epochs=epochs_per_client,
                batch_size=batch_size_per_client, verbose=0,
                feature_columns=feature_columns_for_clients, target_column=target_column_for_clients)
            if training_result and 'model_weights' in training_result:
                collected_client_weights.append(training_result['model_weights'])
                client_data_sizes_for_round.append(training_result['num_samples'])
                successful_clients_this_round.append(client_id)
                # print(f"Client {client.client_id}: Train OK. Acc: {training_result['training_summary']['test_accuracy']:.4f}")
        if not collected_client_weights: print("Server: No weights collected."); return {'round':round_number, 'status':'failed_no_weights', 'successful_clients':[]}
        new_global_weights = self.federated_averaging(collected_client_weights, client_data_sizes_for_round)
        if new_global_weights: self.global_model.set_weights(new_global_weights); # print("Server: Global model updated.") # Verbose
        round_summary = {'round':round_number, 'status':'completed', 'num_selected':len(selected_client_ids), 'successful_clients':successful_clients_this_round, 'client_data_sizes':client_data_sizes_for_round}
        self.aggregation_history.append(round_summary)
        return round_summary

    def evaluate_global_model(self, X_test_global, y_test_global):
        if self.global_model is None: print("Server: Global model not init."); return None
        if X_test_global is None or y_test_global is None or (isinstance(X_test_global, np.ndarray) and X_test_global.size == 0):
            print("Server: No global test data provided."); return None
        X_test_to_eval = X_test_global
        is_lstm_model = any(isinstance(layer, LSTM) for layer in self.global_model.layers)
        if is_lstm_model and len(X_test_to_eval.shape) == 2:
             if self.global_model_input_shape_for_build and len(self.global_model_input_shape_for_build) == 2: # (timesteps, features)
                timesteps, num_features = self.global_model_input_shape_for_build[0], self.global_model_input_shape_for_build[1]
                if X_test_to_eval.shape[1] == num_features:
                    X_test_to_eval = X_test_to_eval.reshape(X_test_to_eval.shape[0], timesteps, num_features)
                else: return None
        loss, accuracy = self.global_model.evaluate(X_test_to_eval, y_test_global, verbose=0)
        y_pred_probs = self.global_model.predict(X_test_to_eval, verbose=0); y_pred_classes = np.argmax(y_pred_probs, axis=1)
        report_dict = {}; conf_matrix_list = []
        try:
            labels = np.union1d(np.unique(y_test_global), np.unique(y_pred_classes))
            report_dict = classification_report(y_test_global,y_pred_classes,labels=labels if len(labels)>0 else None,output_dict=True,zero_division=0)
            conf_matrix_list = confusion_matrix(y_test_global,y_pred_classes,labels=labels if len(labels)>0 else None).tolist()
        except Exception as e: print(f"Server: Could not gen report/matrix: {e}")
        eval_result = {'timestamp': datetime.now().isoformat(), 'loss': float(loss), 'accuracy': float(accuracy), 'classification_report': report_dict, 'confusion_matrix': conf_matrix_list}
        # print(f"Server: Global Model Eval - Accuracy: {accuracy:.4f}, Loss: {loss:.4f}") # Verbose
        return eval_result

    def save_global_model(self, path=None):
        if self.global_model is None: print("Server: No global model."); return None
        if path is None:
            path = os.path.join(MODEL_DIR if 'MODEL_DIR' in globals() else ".", f"fl_global_{self.global_model_type}_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.h5")
        try: self.global_model.save(path); print(f"Server: Global model saved to {path}"); return path
        except Exception as e: print(f"Server: Error saving global model: {e}"); return None

    def load_global_model(self, path):
        try:
            self.global_model = tf.keras.models.load_model(path)
            print(f"Server: Global model loaded from {path}")
            self.global_model_type = 'lstm' if any(isinstance(layer, LSTM) for layer in self.global_model.layers) else 'mlp'
            config = self.global_model.get_config()
            if 'layers' in config and config['layers']:
                first_layer_config = config['layers'][0]['config']
                if 'batch_input_shape' in first_layer_config: self.global_model_input_shape_for_build = first_layer_config['batch_input_shape'][1:]
                elif hasattr(self.global_model, 'input_shape') and isinstance(self.global_model.input_shape, tuple) and len(self.global_model.input_shape) > 1: self.global_model_input_shape_for_build = self.global_model.input_shape[1:]
                else: self.global_model_input_shape_for_build = None
                last_layer_config = config['layers'][-1]['config']
                if 'units' in last_layer_config: self.global_model_num_classes = last_layer_config['units']
                else: self.global_model_num_classes = None
            else: self.global_model_input_shape_for_build = None; self.global_model_num_classes = None;
            if not self.global_model.optimizer: self.global_model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
            # print(f"Server: Loaded model details - Type: {self.global_model_type}, Build Input: {self.global_model_input_shape_for_build}, Classes: {self.global_model_num_classes}")
            return True
        except Exception as e: print(f"Server: Error loading global model from {path}: {e}"); import traceback; traceback.print_exc(); return False

    def save_history(self, path=None):
        if not self.aggregation_history: print("Server: No aggregation history."); return None
        if path is None:
            path = os.path.join(FL_DIR if 'FL_DIR' in globals() else ".", f"fl_agg_hist_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        try:
            with open(path, 'w') as f: json.dump(self.aggregation_history, f, indent=2)
            print(f"Server: Aggregation history saved to {path}")
            return path
        except Exception as e: print(f"Server: Error saving agg history: {e}"); return None

# --- Test Block for FederatedServer (Part 3: Evaluation & Saving) ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing FederatedServer (Part 3: Evaluation & Saving) ---")

    if 'DATA_DIR' not in globals() or not os.path.exists(DATA_DIR) or \
       'MODEL_DIR' not in globals() or not os.path.exists(MODEL_DIR) or \
       'FL_DIR' not in globals() or not os.path.exists(FL_DIR):
        print("⚠️ DATA_DIR, MODEL_DIR or FL_DIR not found. Skipping test. Run Section 1 first.")
    elif 'FederatedClient' not in globals() or 'FederatedServer' not in globals():
        print("⚠️ FederatedClient or FederatedServer class not defined. Run previous sections defining these classes.")
    else:
        server_eval_test = None
        if 'server_test_part2' in locals() and isinstance(server_test_part2, FederatedServer) and server_test_part2.global_model:
            server_eval_test = server_test_part2
            print("Using server instance from Part 2 test for Part 3 test.")
        else:
            print("Creating new server instance for Part 3 test.")
            server_eval_test = FederatedServer()
            client_for_init_eval = FederatedClient(client_id="init_client_for_eval_test_p3")
            dummy_data_path_eval = os.path.join(DATA_DIR, "server_client_dummy_data.csv")
            dummy_feature_cols_for_init = [f'f{i+1}_numeric' for i in range(2)]
            dummy_target_col_for_init = 'target_label'

            if not os.path.exists(dummy_data_path_eval):
                pd.DataFrame({ # Corrected Indentation for this block
                    'f1_numeric': np.random.rand(50), 'f2_numeric': np.random.rand(50),
                    'target_label': np.random.choice(['ClassA', 'ClassB'], 50)
                }).to_csv(dummy_data_path_eval, index=False)
                print(f"Created dummy data for server test: {dummy_data_path_eval}")

            if os.path.exists(dummy_data_path_eval):
                client_for_init_eval.load_data(data_path=dummy_data_path_eval)
                if client_for_init_eval.data is not None:
                     server_eval_test.initialize_global_model(
                         model_type='mlp',
                         intended_feature_columns=dummy_feature_cols_for_init,
                         intended_target_column=dummy_target_col_for_init,
                         client_for_shape_details=client_for_init_eval
                     )
                else:
                     server_eval_test.initialize_global_model(model_type='mlp', default_input_shape=(len(dummy_feature_cols_for_init),), default_num_classes=2)
            else:
                print(f"Dummy data {dummy_data_path_eval} not found, initializing server model with defaults.")
                server_eval_test.initialize_global_model(model_type='mlp', default_input_shape=(len(dummy_feature_cols_for_init),), default_num_classes=2)

        if server_eval_test.global_model:
            print("\n-- Testing Global Model Evaluation --")
            eval_input_shape_for_build = server_eval_test.global_model_input_shape_for_build
            num_eval_samples = 50

            if server_eval_test.global_model_type == 'lstm':
                 X_global_test = np.random.rand(num_eval_samples, eval_input_shape_for_build[0], eval_input_shape_for_build[1])
            else: # MLP
                 X_global_test = np.random.rand(num_eval_samples, eval_input_shape_for_build[0])
            y_global_test = np.random.randint(0, server_eval_test.global_model_num_classes, num_eval_samples)

            eval_results = server_eval_test.evaluate_global_model(X_global_test, y_global_test)
            if eval_results: print(f"Global model evaluation (on random data): Accuracy {eval_results['accuracy']:.4f}")

            print("\n-- Testing Save/Load Global Model --")
            saved_model_path = server_eval_test.save_global_model()
            if saved_model_path and os.path.exists(saved_model_path):
                server_new_load_test = FederatedServer()
                loaded = server_new_load_test.load_global_model(saved_model_path)
                if loaded and server_new_load_test.global_model:
                    print("Global model loaded into new server instance successfully.")
                    re_eval_results = server_new_load_test.evaluate_global_model(X_global_test, y_global_test)
                    if re_eval_results: print(f"Re-evaluation of loaded global model: Accuracy {re_eval_results['accuracy']:.4f}")

            print("\n-- Testing Save History --")
            if not server_eval_test.aggregation_history:
                server_eval_test.aggregation_history.append({'round':0, 'status':'dummy_for_save_test'})
            server_eval_test.save_history()
        else:
            print("Global model not initialized on server_eval_test. Skipping Part 3 tests.")
    print("\n--- End of FederatedServer (Part 3) Test ---")

print("\n✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.")


--- Testing FederatedServer (Part 3: Evaluation & Saving) ---
Creating new server instance for Part 3 test.
Server: Initialized global MLP model (Build Input Shape: (2,), Classes: 2).

-- Testing Global Model Evaluation --


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Global model evaluation (on random data): Accuracy 0.4400

-- Testing Save/Load Global Model --
Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234346.h5
Server: Global model loaded from /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234346.h5
Global model loaded into new server instance successfully.
Re-evaluation of loaded global model: Accuracy 0.4400

-- Testing Save History --
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_agg_hist_20250528_234347.json

--- End of FederatedServer (Part 3) Test ---

✅ Section 7 (FederatedServer Class - Part 3: Evaluation & Saving) is ready.


In [11]:
# Imports from Section 1 should still be in effect.
# FederatedClient and FederatedServer classes should be defined from previous sections.

def simulate_federated_learning(num_clients=5, num_rounds=5, model_type='mlp',
                                data_path=None, feature_columns=None, target_column=None,
                                epochs_per_client=5, batch_size_per_client=32,
                                clients_per_round=None):
    """
    Simulate a federated learning scenario.
    """
    print(f"\n--- Starting Federated Learning Simulation ---")
    print(f"Config: {num_clients} clients, {num_rounds} rounds, Model: {model_type.upper()}")
    print(f"Data Path: {data_path if data_path else 'Synthetic Data'}")

    server = FederatedServer()

    client_data_list = []
    if data_path and os.path.exists(data_path):
        print(f"Loading and partitioning data from: {data_path}")
        try:
            if data_path.endswith('.csv'):
                full_df = pd.read_csv(data_path, low_memory=False)
            elif data_path.endswith(('.json', '.jsonl')):
                full_df = pd.read_json(data_path, lines=data_path.endswith('.jsonl'))
            else:
                raise ValueError(f"Unsupported data file type: {data_path}")

            if full_df.empty or len(full_df) < num_clients:
                print(f"⚠️ Warning: Dataset at {data_path} is empty or too small for {num_clients} clients. Using synthetic data.")
                data_path = None
            else:
                num_total_samples = len(full_df)
                samples_per_client = num_total_samples // num_clients
                for i in range(num_clients):
                    start_idx = i * samples_per_client
                    end_idx = (i + 1) * samples_per_client if i < num_clients - 1 else num_total_samples
                    client_df = full_df.iloc[start_idx:end_idx].copy()
                    if not client_df.empty:
                        client_data_list.append(client_df)

                if not client_data_list:
                    print("⚠️ All client data partitions were empty. Falling back to synthetic data.")
                    data_path = None
        except Exception as e:
            print(f"❌ Error loading or partitioning real data: {e}. Falling back to synthetic data.")
            data_path = None

    if not client_data_list:
        print("Generating synthetic data for clients...")
        client_data_list = []
        num_synthetic_features = len(feature_columns) if feature_columns and isinstance(feature_columns, list) else 10
        cols = feature_columns if feature_columns and isinstance(feature_columns, list) else [f"feature_{j}" for j in range(num_synthetic_features)]

        for i in range(num_clients):
            n_samples = random.randint(200, 500)
            X_synthetic = np.random.rand(n_samples, num_synthetic_features)
            y_synthetic = np.random.randint(0, 2, size=n_samples)

            df_synthetic = pd.DataFrame(X_synthetic, columns=cols)
            df_synthetic[target_column if target_column else 'label'] = y_synthetic
            client_data_list.append(df_synthetic)

    actual_num_clients = len(client_data_list)
    if actual_num_clients == 0: print("❌ No data for any client. Aborting."); return None

    for i in range(actual_num_clients):
        client = FederatedClient(client_id=f"fl_client_{i}", data=client_data_list[i])
        server.add_client(client)

    first_client_with_data = next((c for c in server.clients.values() if c.data is not None and not c.data.empty), None)
    if not first_client_with_data: print("❌ No client has data for model init. Aborting."); return None

    server.initialize_global_model(
        model_type=model_type,
        intended_feature_columns=feature_columns, # Pass intended features for shape inference
        intended_target_column=target_column,     # Pass intended target for shape inference
        client_for_shape_details=first_client_with_data
    )
    if server.global_model is None: print("❌ Failed to initialize global model. Aborting."); return None

    for r in range(1, num_rounds + 1):
        server.train_round(
            round_number=r,
            num_selected_clients=clients_per_round if clients_per_round else actual_num_clients,
            epochs_per_client=epochs_per_client,
            batch_size_per_client=batch_size_per_client,
            feature_columns_for_clients=feature_columns, # Pass to ensure clients use correct features
            target_column_for_clients=target_column     # Pass to ensure clients use correct target
        )

    print("\n--- Performing Final Global Model Evaluation ---")
    all_X_test_scaled, all_y_test = [], []
    final_eval_results = None

    for client_id, client in server.clients.items():
        if client.data is not None and not client.data.empty:
            preproc_res = client.preprocess_data(
                feature_columns=feature_columns,
                target_column=target_column,
                reshape_for_lstm=(server.global_model_type == 'lstm')
            )
            if preproc_res:
                _, X_test_client, _, y_test_client, _ = preproc_res
                if len(X_test_client) > 0:
                    all_X_test_scaled.append(X_test_client)
                    all_y_test.append(y_test_client)

    if all_X_test_scaled and all_y_test:
        try:
            global_X_test = np.concatenate(all_X_test_scaled, axis=0)
            global_y_test = np.concatenate(all_y_test, axis=0)
            if len(global_X_test) > 0:
                 final_eval_results = server.evaluate_global_model(global_X_test, global_y_test)
            else: print("No test data for final global model evaluation after concatenation.")
        except ValueError as ve:
            print(f"Error concatenating client test data for global evaluation: {ve}")
    else:
        print("No client test data for final global model evaluation.")

    model_save_path = server.save_global_model()
    history_save_path = server.save_history()

    print("\n--- Federated Learning Simulation Completed ---")
    if final_eval_results: print(f"Final Global Model Accuracy: {final_eval_results.get('accuracy', 'N/A'):.4f}")

    return {"num_clients": actual_num_clients, "num_rounds": num_rounds, "model_type": model_type, "final_evaluation": final_eval_results, "global_model_path": model_save_path, "aggregation_history_path": history_save_path, "aggregation_history": server.aggregation_history}


def compare_federated_vs_centralized(data_path, feature_columns, target_column,
                                     num_clients=5, num_rounds=5, model_type='mlp',
                                     epochs_per_client=5, batch_size_per_client=32,
                                     centralized_epochs=20, centralized_batch_size=32):
    print("\n--- Comparing Federated Learning vs. Centralized Training ---")
    if not data_path or not os.path.exists(data_path):
        print(f"❌ Data path '{data_path}' not found. Comparison aborted.")
        return None

    print("\n--- Running Federated Learning for Comparison ---")
    fl_sim_summary = simulate_federated_learning(
        num_clients=num_clients, num_rounds=num_rounds, model_type=model_type,
        data_path=data_path, feature_columns=feature_columns, target_column=target_column,
        epochs_per_client=epochs_per_client, batch_size_per_client=batch_size_per_client
    )
    fl_accuracy = 0.0
    if fl_sim_summary and fl_sim_summary.get('final_evaluation') and isinstance(fl_sim_summary['final_evaluation'], dict) :
        fl_accuracy = fl_sim_summary['final_evaluation'].get('accuracy', 0.0)

    print("\n--- Running Centralized Learning for Comparison ---")
    cen_accuracy = 0.0
    try:
        full_df_cen = pd.read_csv(data_path, low_memory=False)
        full_df_cen.replace([np.inf, -np.inf], np.nan, inplace=True)

        # Ensure all specified feature_columns and target_column actually exist before trying to use them
        actual_feature_cols_cen = [col for col in feature_columns if col in full_df_cen.columns]
        if len(actual_feature_cols_cen) != len(feature_columns):
            print(f"Warning: Not all specified feature_columns found in centralized dataset. Using: {actual_feature_cols_cen}")
        if target_column not in full_df_cen.columns:
            raise ValueError(f"Target column '{target_column}' not found in centralized dataset for comparison.")

        cols_to_check_dropna = actual_feature_cols_cen + [target_column]
        full_df_cen.dropna(subset=cols_to_check_dropna, inplace=True)
        if full_df_cen.empty: raise ValueError("Centralized dataset empty after dropna.")

        X_cen = full_df_cen[actual_feature_cols_cen]
        y_cen = full_df_cen[target_column]

        X_numeric_cen = X_cen.select_dtypes(include=np.number)
        if X_numeric_cen.shape[1] < X_cen.shape[1]:
            print(f"Centralized: Warning - Dropping non-numeric columns: {X_cen.select_dtypes(exclude=np.number).columns.tolist()}")
        X_cen = X_numeric_cen
        if X_cen.empty : raise ValueError("No numeric features left for centralized training.")

        label_encoder_cen = LabelEncoder(); y_encoded_cen = label_encoder_cen.fit_transform(y_cen)
        num_classes_cen = len(label_encoder_cen.classes_)
        if num_classes_cen < 2: raise ValueError("Centralized training needs at least 2 classes in target.")

        X_train_cen, X_test_cen, y_train_cen, y_test_cen = train_test_split(
            X_cen, y_encoded_cen, test_size=0.2, random_state=42, stratify=y_encoded_cen
        )

        scaler_cen = StandardScaler(); X_train_scaled_cen = scaler_cen.fit_transform(X_train_cen); X_test_scaled_cen = scaler_cen.transform(X_test_cen)

        temp_builder_cen = FederatedClient("cen_builder_temp")
        input_shape_cen_build = None
        if model_type.lower() == 'lstm':
            X_train_scaled_cen = X_train_scaled_cen.reshape(X_train_scaled_cen.shape[0], 1, X_train_scaled_cen.shape[1])
            X_test_scaled_cen = X_test_scaled_cen.reshape(X_test_scaled_cen.shape[0], 1, X_test_scaled_cen.shape[1])
            input_shape_cen_build = (X_train_scaled_cen.shape[1], X_train_scaled_cen.shape[2])
            centralized_model = temp_builder_cen.build_lstm_model(input_shape_cen_build, num_classes_cen)
        else:
            input_shape_cen_build = (X_train_scaled_cen.shape[1],)
            centralized_model = temp_builder_cen.build_mlp_model(input_shape_cen_build, num_classes_cen)
        del temp_builder_cen

        print(f"Centralized {model_type.upper()} model built. Input: {input_shape_cen_build}, Classes: {num_classes_cen}")
        centralized_model.fit(X_train_scaled_cen, y_train_cen, epochs=centralized_epochs, batch_size=centralized_batch_size, verbose=0,
                              validation_data=(X_test_scaled_cen, y_test_cen),
                              callbacks=[EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)])

        _, cen_accuracy = centralized_model.evaluate(X_test_scaled_cen, y_test_cen, verbose=0)
        print(f"Centralized Model Test Accuracy: {cen_accuracy:.4f}")
    except Exception as e:
        print(f"❌ Error during centralized training: {e}"); import traceback; traceback.print_exc()

    print("\n--- Comparison Summary ---")
    print(f"Federated Learning ({model_type.upper()}) Final Global Accuracy: {fl_accuracy:.4f}")
    print(f"Centralized Learning ({model_type.upper()}) Test Accuracy: {cen_accuracy:.4f}")

    return {"federated_accuracy": fl_accuracy, "centralized_accuracy": cen_accuracy}

# --- Test Block for Simulation Functions ---
if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Testing Simulation Functions ---")
    if 'DATA_DIR' not in globals() or 'MODEL_DIR' not in globals() or 'FL_DIR' not in globals():
        print("⚠️ DATA_DIR, MODEL_DIR, or FL_DIR not defined. Skipping. Run Section 1 first.")
    else:
        REAL_DATASET_PATH = "/content/drive/MyDrive/Colab Notebooks/datasets/ML-EdgeIIoT-dataset.csv"
        REAL_FEATURE_COLUMNS = [
            'arp.hw.size', 'http.content_length', 'http.response', 'http.tls_port',
            'tcp.ack_raw', 'tcp.checksum', 'tcp.connection.fin', 'tcp.connection.rst',
            'tcp.connection.syn', 'tcp.connection.synack', 'tcp.dstport', 'tcp.flags.ack',
            'tcp.len', 'udp.stream', 'udp.time_delta', 'dns.qry.qu', 'dns.qry.type',
            'dns.retransmission', 'dns.retransmit_request', 'dns.retransmit_request_in',
            'mqtt.conflag.cleansess', 'mqtt.hdrflags', 'mqtt.len', 'mqtt.msg_decoded_as',
            'mbtcp.len', 'mbtcp.trans_id', 'mbtcp.unit_id'
        ]
        REAL_TARGET_COLUMN = 'Attack_label'

        if os.path.exists(REAL_DATASET_PATH):
            print(f"\n--- Test 1: Running simulate_federated_learning with '{REAL_DATASET_PATH}' ---")
            fl_sim_results = simulate_federated_learning(
                num_clients=2, num_rounds=1, model_type='mlp',
                data_path=REAL_DATASET_PATH, feature_columns=REAL_FEATURE_COLUMNS, target_column=REAL_TARGET_COLUMN,
                epochs_per_client=1, batch_size_per_client=32, clients_per_round=2)
            if fl_sim_results: print("FL simulation with real data completed.")
        else:
            print(f"⚠️ Dataset {REAL_DATASET_PATH} not found. Running FL with synthetic data.")
            simulate_federated_learning(num_clients=2, num_rounds=1, model_type='mlp', epochs_per_client=1,
                                        feature_columns=['f1','f2','f3'], target_column='label') # Pass example feature/target for synthetic

        print(f"\n--- Test 2: Running compare_federated_vs_centralized (using synthetic data for test robustness) ---")
        synthetic_compare_path = os.path.join(DATA_DIR, "compare_synthetic_data.csv")
        synthetic_features = [f'synth_feat_{k}' for k in range(5)]
        pd.DataFrame(
            np.random.rand(200, len(synthetic_features)), columns=synthetic_features
        ).assign(target=np.random.randint(0,2,200)).to_csv(synthetic_compare_path, index=False)

        compare_results = compare_federated_vs_centralized(
            data_path=synthetic_compare_path, feature_columns=synthetic_features, target_column='target',
            num_clients=2, num_rounds=1, model_type='mlp', epochs_per_client=1, centralized_epochs=2)
        if compare_results: print("Comparison simulation completed.")
    print("\n--- End of Simulation Functions Test ---")

print("\n✅ Section 8 (Federated Learning - Simulation Functions) is ready.")


--- Testing Simulation Functions ---
⚠️ Dataset /content/drive/MyDrive/Colab Notebooks/datasets/ML-EdgeIIoT-dataset.csv not found. Running FL with synthetic data.

--- Starting Federated Learning Simulation ---
Config: 2 clients, 1 rounds, Model: MLP
Data Path: Synthetic Data
Generating synthetic data for clients...
Server: Initialized global MLP model (Build Input Shape: (3,), Classes: 2).

--- Server: Starting FL Round 1 with 2 clients: ['fl_client_1', 'fl_client_0'] ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



--- Performing Final Global Model Evaluation ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234353.h5
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_agg_hist_20250528_234353.json

--- Federated Learning Simulation Completed ---
Final Global Model Accuracy: 0.4872

--- Test 2: Running compare_federated_vs_centralized (using synthetic data for test robustness) ---

--- Comparing Federated Learning vs. Centralized Training ---

--- Running Federated Learning for Comparison ---

--- Starting Federated Learning Simulation ---
Config: 2 clients, 1 rounds, Model: MLP
Data Path: /content/federated_ids_ai_project/data/compare_synthetic_data.csv
Loading and partitioning data from: /content/federated_ids_ai_project/data/compare_synthetic_data.csv
Server: Initialized global MLP model (Build Input Shape: (5,), Classes: 2).

--- Server: Starting FL Round 1 with 2 clients: ['fl_client_1', 'fl_client_0'] ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



--- Performing Final Global Model Evaluation ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234359.h5
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_agg_hist_20250528_234359.json

--- Federated Learning Simulation Completed ---
Final Global Model Accuracy: 0.5250

--- Running Centralized Learning for Comparison ---
Centralized MLP model built. Input: (5,), Classes: 2
Centralized Model Test Accuracy: 0.5000

--- Comparison Summary ---
Federated Learning (MLP) Final Global Accuracy: 0.5250
Centralized Learning (MLP) Test Accuracy: 0.5000
Comparison simulation completed.

--- End of Simulation Functions Test ---

✅ Section 8 (Federated Learning - Simulation Functions) is ready.


In [12]:
# Imports from Section 1 should still be in effect.
# All classes (FederatedClient, FederatedServer) and simulation functions
# (simulate_federated_learning, compare_federated_vs_centralized) should be defined
# from previous sections.
# Global directories (BASE_DIR, DATA_DIR, MODEL_DIR, FL_DIR) should also be defined.

def main():
    """
    Main function to initialize and run the IDS system for federated learning.
    Handles argument parsing (simulated for Colab) and orchestrates the simulation.
    """
    class SimulatedArgs:
        def __init__(self):
            self.data = "/content/drive/MyDrive/Colab Notebooks/datasets/ML-EdgeIIoT-dataset.csv"
            self.feature_columns = [
                'arp.hw.size', 'http.content_length', 'http.response', 'http.tls_port',
                'tcp.ack_raw', 'tcp.checksum', 'tcp.connection.fin', 'tcp.connection.rst',
                'tcp.connection.syn', 'tcp.connection.synack', 'tcp.dstport', 'tcp.flags.ack',
                'tcp.len', 'udp.stream', 'udp.time_delta', 'dns.qry.qu', 'dns.qry.type',
                'dns.retransmission', 'dns.retransmit_request', 'dns.retransmit_request_in',
                'mqtt.conflag.cleansess', 'mqtt.hdrflags', 'mqtt.len', 'mqtt.msg_decoded_as',
                'mbtcp.len', 'mbtcp.trans_id', 'mbtcp.unit_id'
            ]
            self.target_column = 'Attack_label'
            self.clients = 3
            self.rounds = 2
            self.epochs = 2
            self.batch_size = 32
            self.model = "mlp"
            self.clients_per_round = 2
            self.centralized_epochs = 5
            self.compare = False

    args = SimulatedArgs()

    print("--- Federated Learning System Orchestration ---")
    print(f"Configuration: Dataset Path='{args.data}'")
    print(f"               Clients={args.clients}, Rounds={args.rounds}, Epochs/Client={args.epochs}, Model='{args.model.upper()}'")
    if args.compare:
        print("Mode: Comparing Federated vs. Centralized Learning")
    else:
        print("Mode: Running Federated Learning Simulation Only")

    if args.data and not os.path.exists(args.data): # Check if args.data is not None before os.path.exists
        print(f"⚠️ WARNING: Dataset not found at '{args.data}'. "
              "Simulation will proceed with internally generated SYNTHETIC data for clients.")
    elif args.data:
        print(f"✅ Using dataset: {args.data}")
    else: # args.data is None
        print(f"ℹ️ No dataset path provided. Simulation will use SYNTHETIC data.")


    if args.compare:
        results = compare_federated_vs_centralized(
            data_path=args.data,
            feature_columns=args.feature_columns,
            target_column=args.target_column,
            num_clients=args.clients,
            num_rounds=args.rounds,
            model_type=args.model,
            epochs_per_client=args.epochs,
            batch_size_per_client=args.batch_size,
            centralized_epochs=args.centralized_epochs,
        )
    else:
        results = simulate_federated_learning(
            num_clients=args.clients,
            num_rounds=args.rounds,
            model_type=args.model,
            data_path=args.data,
            feature_columns=args.feature_columns,
            target_column=args.target_column,
            epochs_per_client=args.epochs,
            batch_size_per_client=args.batch_size,
            clients_per_round=args.clients_per_round
        )

    if results:
        results_filename = f"fl_results_{args.model}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        results_path = os.path.join(FL_DIR, results_filename)
        try:
            with open(results_path, 'w') as f:
                class NpEncoder(json.JSONEncoder):
                    def default(self, obj):
                        if isinstance(obj, np.integer): return int(obj)
                        if isinstance(obj, np.floating): return float(obj)
                        if isinstance(obj, np.ndarray): return obj.tolist()
                        return super(NpEncoder, self).default(obj)
                json.dump(results, f, indent=2, cls=NpEncoder)
            print(f"\n📊 Comprehensive simulation results saved to: {results_path}")
        except Exception as e:
            print(f"❌ Error saving comprehensive results: {e}")
    else:
        print("\nℹ️ Simulation did not produce results to save (possibly due to an earlier error or empty data).")

    print("\n🏁 Federated Learning Main Execution Finished.")

if __name__ == "__main__" and 'google.colab' in sys.modules:
    print("\n--- Running main() for Federated Learning Demonstration ---")
    essential_items = [
        'FederatedClient', 'FederatedServer',
        'simulate_federated_learning', 'compare_federated_vs_centralized',
        'BASE_DIR', 'DATA_DIR', 'MODEL_DIR', 'FL_DIR'
    ]
    all_defined = True
    for item_name in essential_items:
        if item_name not in globals():
            print(f"🔴 CRITICAL ERROR for main(): Required item '{item_name}' is not defined. "
                  "Ensure all previous code sections (1-8) ran successfully.")
            all_defined = False
            break

    if all_defined:
        main()
    else:
        print("\n❌ Main function execution aborted due to missing definitions.")

print("\n✅ Section 9 (Federated Learning - Main Execution Block) is ready.")


--- Running main() for Federated Learning Demonstration ---
--- Federated Learning System Orchestration ---
Configuration: Dataset Path='/content/drive/MyDrive/Colab Notebooks/datasets/ML-EdgeIIoT-dataset.csv'
               Clients=3, Rounds=2, Epochs/Client=2, Model='MLP'
Mode: Running Federated Learning Simulation Only

--- Starting Federated Learning Simulation ---
Config: 3 clients, 2 rounds, Model: MLP
Data Path: /content/drive/MyDrive/Colab Notebooks/datasets/ML-EdgeIIoT-dataset.csv
Generating synthetic data for clients...
Server: Initialized global MLP model (Build Input Shape: (27,), Classes: 2).

--- Server: Starting FL Round 1 with 2 clients: ['fl_client_1', 'fl_client_0'] ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



--- Server: Starting FL Round 2 with 2 clients: ['fl_client_2', 'fl_client_1'] ---


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



--- Performing Final Global Model Evaluation ---




Server: Global model saved to /content/federated_ids_ai_project/models/fl_global_mlp_model_20250528_234414.h5
Server: Aggregation history saved to /content/federated_ids_ai_project/federated_outputs/fl_agg_hist_20250528_234414.json

--- Federated Learning Simulation Completed ---
Final Global Model Accuracy: 0.4583

📊 Comprehensive simulation results saved to: /content/federated_ids_ai_project/federated_outputs/fl_results_mlp_20250528_234414.json

🏁 Federated Learning Main Execution Finished.

✅ Section 9 (Federated Learning - Main Execution Block) is ready.
