In [None]:
import numpy as np
import pandas as pd
import random
import cv2
import os
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

import tensorflow as tf
from tensorflow import expand_dims
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Input, Lambda
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K

In [None]:
!pip install imutils
from imutils import paths

In [None]:
import time

# Running total (in seconds) of time spent inside test_model(); test_model
# declares it `global` and adds each evaluation's elapsed time to it.
cumulative_time: float = 0.0

def create_clients(image_list, label_list, num_clients=100, initial='clients',
                   iid=True, seed=None):
    """
    Create client data shards for federated learning.

    By default (``iid=True``) the samples are shuffled before sharding, so each
    client receives an approximately IID slice of the data. With ``iid=False``
    the samples are sorted by class before sharding, so each client sees only
    one or a few classes (a non-IID split).

    Returns:
        A dictionary with keys as clients' names and values as data shards
        (lists of ``(image, label)`` tuples).

    Args:
        image_list: Sequence/array of training images.
        label_list: Sequence/array of binarized labels, one row per image.
        num_clients: Number of federated members (clients).
        initial: The clients' name prefix, e.g. 'client' -> 'client_1'.
        iid: If True, shuffle before sharding (IID split); if False, sort by
            class before sharding (non-IID split).
        seed: Optional seed for the IID shuffle. ``None`` uses the global
            ``random`` module state, so ``random.seed(...)`` still controls it.

    Raises:
        ValueError: If ``num_clients`` is not positive or exceeds the number of
            samples (which would otherwise produce empty shards).

    Note:
        The trailing ``len(data) % num_clients`` samples are dropped so that
        every shard has the same size.
    """
    if num_clients <= 0:
        raise ValueError(f"num_clients must be positive, got {num_clients}")

    # Create a list of client names
    client_names = [f'{initial}_{i+1}' for i in range(num_clients)]

    data = list(zip(image_list, label_list))
    if len(data) < num_clients:
        raise ValueError(
            f"Cannot split {len(data)} samples across {num_clients} clients"
        )

    if iid:
        # Shuffle so every shard is a random sample of all classes.
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(data)
    else:
        labels = np.asarray(label_list)
        if labels.ndim > 1 and labels.shape[-1] > 1:
            class_ids = np.argmax(labels, axis=-1)
        else:
            # Single-column (binary 0/1) labels: argmax would always be 0,
            # so use the label value itself as the class id.
            class_ids = labels.reshape(len(labels), -1)[:, 0]
        # sorted() is stable: samples keep their relative order within a class.
        order = sorted(range(len(data)), key=lambda i: class_ids[i])
        data = [data[i] for i in order]

    # Shard data into equal contiguous chunks, one per client
    size = len(data) // num_clients
    return {
        name: data[k * size:(k + 1) * size]
        for k, name in enumerate(client_names)
    }


def batch_data(data_shard, batch_size=32):
    """
    Turn one client's list of (image, label) pairs into a shuffled, batched
    ``tf.data.Dataset``.

    Args:
        data_shard: List of (image, label) tuples
        batch_size: Number of samples per batch

    Returns:
        A prefetching ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
        The shuffle buffer covers the whole shard.
    """
    images, labels = [], []
    for sample in data_shard:
        images.append(sample[0])
        labels.append(sample[1])

    features = np.array(images)
    targets = np.array(labels)

    return (
        tf.data.Dataset.from_tensor_slices((features, targets))
        .shuffle(buffer_size=len(data_shard))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


def weight_scalling_factor(clients_batched, client_name, participating_clients=None):
    """
    Return the FedAvg weight N_k / N for one client.

    N_k is the number of samples held by ``client_name`` and N is the total
    number of samples held by the clients taking part in the aggregation. For
    a round's weights to sum to 1 (a weighted *average* rather than a shrunken
    sum), N must be computed over the clients that actually participate in
    that round, not over every registered client.

    Args:
        clients_batched: Mapping of client name -> iterable of batches, where
            each batch's first element (the images) supports ``len``.
        client_name: Name of the client whose factor is requested.
        participating_clients: Optional iterable of client names taking part in
            this aggregation. ``None`` (the default, kept for backward
            compatibility) counts every client in ``clients_batched``. That is
            only correct when all clients participate.

    Returns:
        The scaling factor N_k / N as a float.

    Raises:
        KeyError: If ``client_name`` or a participating client is unknown.
        ZeroDivisionError: If the counted clients hold no samples.
    """
    def _num_samples(dataset):
        # Each batch is an (images, labels) pair; count the images.
        return sum(len(batch[0]) for batch in dataset)

    if participating_clients is None:
        participating_clients = clients_batched.keys()

    client_len = _num_samples(clients_batched[client_name])
    total_len = sum(_num_samples(clients_batched[name]) for name in participating_clients)

    return client_len / total_len


def scale_model_weights(weights, scalar):
    """
    Multiply every weight array of a model by the same scalar.

    Args:
        weights: List of weight arrays (e.g. from ``model.get_weights()``)
        scalar: Factor applied to every array

    Returns:
        A new list holding ``w * scalar`` for each entry ``w``, in order.
    """
    scaled = []
    for layer_weights in weights:
        scaled.append(layer_weights * scalar)
    return scaled


def apply_threshold_normalization(weight_updates, threshold_percentile=95):
    """
    Clip each array element-wise to +/- its own ``threshold_percentile``-th
    percentile of absolute values.

    Intended for weight *updates* (deltas between two weight sets). It limits
    outlier entries that could destabilise training, e.g. from noisy or
    anomalous clients.

    Args:
        weight_updates: List of weight arrays.
        threshold_percentile: Percentile (0-100) of |values| used as each
            array's clipping threshold (default: 95).

    Returns:
        List of clipped arrays (new arrays; inputs are not modified). Empty
        arrays are passed through unchanged.
    """
    normalized_updates = []

    for weight_layer in weight_updates:
        weight_layer = np.asarray(weight_layer)
        if weight_layer.size == 0:
            # np.percentile raises on empty input, and there is nothing to clip.
            normalized_updates.append(weight_layer)
            continue
        threshold = np.percentile(np.abs(weight_layer), threshold_percentile)
        normalized_updates.append(np.clip(weight_layer, -threshold, threshold))

    return normalized_updates


def sum_scaled_weights(scaled_weight_list, reference_weights=None,
                       threshold_percentile=95):
    """
    Aggregate scaled client weights: W_t = sum_k (N_k/N * W_k^t).

    When ``reference_weights`` (typically the global weights the clients
    started from) is given, the aggregated *update* ``W_t - reference`` is
    clipped with ``apply_threshold_normalization`` and added back. That
    limits extreme updates without distorting the weights themselves.

    Without a reference the plain weighted sum (FedAvg) is returned. Clipping
    the absolute weights instead would clamp the largest 5% of every array
    every round. That includes legitimate large kernel entries and
    BatchNormalization moving variances.

    Args:
        scaled_weight_list: Non-empty list of per-client weight lists, each
            already multiplied by its client's scaling factor.
        reference_weights: Optional weight list of the same structure; enables
            update clipping relative to it.
        threshold_percentile: Percentile used when clipping updates.

    Returns:
        Aggregated weight list.

    Raises:
        ValueError: If ``scaled_weight_list`` is empty.
    """
    if not scaled_weight_list:
        raise ValueError("scaled_weight_list must contain at least one client's weights")

    # Accumulate in place (layer by layer) to avoid stacking every client's
    # copy of large layers in memory at once.
    avg_weights = [np.zeros_like(w) for w in scaled_weight_list[0]]
    for scaled_weights in scaled_weight_list:
        for i, weight_layer in enumerate(scaled_weights):
            avg_weights[i] += weight_layer

    if reference_weights is None:
        return avg_weights

    updates = [new - old for new, old in zip(avg_weights, reference_weights)]
    clipped_updates = apply_threshold_normalization(updates, threshold_percentile)
    return [old + delta for old, delta in zip(reference_weights, clipped_updates)]


def compute_weight_divergence(old_weights, new_weights):
    """
    Euclidean (L2) distance between two weight lists, treating all arrays
    together as one flattened vector. Handy for monitoring how far an
    aggregation round moved the model.

    Args:
        old_weights: Weight arrays before the update
        new_weights: Weight arrays after the update (paired positionally)

    Returns:
        sqrt(sum over arrays of sum((old - new) ** 2))
    """
    squared_total = sum(
        (np.sum((before - after) ** 2) for before, after in zip(old_weights, new_weights)),
        0.0,
    )
    return np.sqrt(squared_total)


# Summary of the helpers defined in this cell (one print, same output lines).
print("\n".join([
    "✓ Federated Learning helper functions loaded successfully",
    "  - create_clients: Creates client data shards for federated learning",
    "  - batch_data: Creates batched TF datasets with shuffling",
    "  - weight_scalling_factor: Computes client contribution weight (N_k/N)",
    "  - scale_model_weights: Scales weights by scalar factor",
    "  - apply_threshold_normalization: Clips extreme weight updates",
    "  - sum_scaled_weights: Aggregates scaled weights with normalization",
    "  - compute_weight_divergence: Monitors model update stability",
]))

In [None]:
def test_model(X_test, Y_test, model, comm_round):
    """
    Evaluate ``model`` on a held-out set, print a one-line summary and return
    ``(accuracy, loss)``.

    The wall-clock time of this evaluation is added to the module-level
    ``cumulative_time`` counter. The running total is included in the printed
    line.

    Args:
        X_test: Inputs accepted by ``model.predict``.
        Y_test: Labels with one column per class (argmax over axis 1 is taken
            as the true class).
        model: Keras model whose ``predict`` returns class probabilities.
        comm_round: Round index; only used in the printed message.

    Returns:
        Tuple ``(acc, loss)``: accuracy of the argmax predictions and the Keras
        ``CategoricalCrossentropy`` loss as a Python float.
    """
    global cumulative_time  # running total of evaluation time across calls

    tic = time.time()

    probs = model.predict(X_test)
    cross_entropy = tf.keras.losses.CategoricalCrossentropy()
    loss = float(cross_entropy(Y_test, probs).numpy())
    acc = accuracy_score(np.argmax(Y_test, axis=1), np.argmax(probs, axis=1))

    toc = time.time()
    cumulative_time += toc - tic

    print(f'comm_round: {comm_round} | global_acc: {acc:.3%} | '
          f'global_loss: {loss:.4f} | cumulative_time: {cumulative_time:.2f} seconds')

    return acc, loss


In [None]:
# ============================================================================
# IMPROVED MACNN ARCHITECTURE WITH BETTER INITIALIZATION
# ============================================================================

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, regularizers

def attention_module(inputs):
    """
    Squeeze-and-Excitation style channel gating.

    Global-average-pools ``inputs`` to one value per channel, passes it
    through a Dense(ReLU) bottleneck of ``max(channels // 16, 1)`` units and a
    Dense(sigmoid) back to ``channels`` units, then multiplies each channel of
    ``inputs`` by its gate.
    """
    channels = inputs.shape[-1]
    bottleneck = max(channels // 16, 1)

    squeezed = layers.GlobalAveragePooling2D()(inputs)
    excited = layers.Dense(units=bottleneck, activation='relu',
                           kernel_initializer='he_normal')(squeezed)
    excited = layers.Dense(units=channels, activation='sigmoid',
                           kernel_initializer='he_normal')(excited)
    gate = layers.Reshape((1, 1, channels))(excited)

    return layers.Multiply()([inputs, gate])


def spatial_attention(input_feature):
    """
    Spatial attention map.

    Takes the mean and the max over the channel axis (axis 3), concatenates
    the two single-channel maps, and turns them into a per-pixel sigmoid mask
    with a 7x7 convolution. That mask then rescales ``input_feature``.
    """
    def _channel_mean(t):
        return tf.reduce_mean(t, axis=3, keepdims=True)

    def _channel_max(t):
        return tf.reduce_max(t, axis=3, keepdims=True)

    mean_map = layers.Lambda(_channel_mean)(input_feature)
    max_map = layers.Lambda(_channel_max)(input_feature)
    descriptor = layers.Concatenate(axis=3)([mean_map, max_map])

    mask = layers.Conv2D(filters=1, kernel_size=7, strides=1, padding='same',
                         activation='sigmoid', kernel_initializer='he_normal')(descriptor)
    return layers.Multiply()([input_feature, mask])


def channel_attention(input_feature, ratio=8):
    """
    Channel attention with a two-layer MLP shared by an average-pooled and a
    max-pooled branch.

    Each branch ends in a sigmoid and the two are summed, so the resulting
    per-channel gate lies in (0, 2) before it rescales ``input_feature``.

    Args:
        input_feature: 4-D feature tensor (channels last).
        ratio: Channel reduction ratio of the MLP's hidden layer
            (hidden units = max(channels // ratio, 1)).
    """
    channels = input_feature.shape[-1]
    hidden_units = max(channels // ratio, 1)

    mlp_hidden = layers.Dense(hidden_units, activation='relu',
                              kernel_initializer='he_normal')
    mlp_output = layers.Dense(channels, activation='sigmoid',
                              kernel_initializer='he_normal')

    def _shared_mlp(t):
        return mlp_output(mlp_hidden(t))

    avg_branch = _shared_mlp(layers.GlobalAveragePooling2D()(input_feature))
    max_branch = _shared_mlp(layers.GlobalMaxPooling2D()(input_feature))

    gate = layers.Add()([avg_branch, max_branch])
    gate = layers.Reshape((1, 1, channels))(gate)
    return layers.Multiply()([input_feature, gate])


def MACNN(input_shape, num_classes, l2_reg=1e-4):
    """
    Build the Multi-Attention CNN.

    Layout:
      * three conv stages (32, 64, 128 filters). Each stage has two 3x3
        Conv2D(ReLU, He-normal, L2) + BatchNorm layers, then 2x2 max-pooling
        and dropout (0.25, 0.3, 0.4);
      * SE gating + spatial attention after stage 1, and SE gating + channel
        attention after stage 2;
      * Flatten, then Dense(512) and Dense(256), each with ReLU, L2,
        BatchNorm and dropout 0.5;
      * a Dense(num_classes) softmax output (Glorot-uniform init).

    NOTE(review): with ``num_classes == 1`` the softmax output is constantly
    1.0, so labels must have at least two columns.

    Args:
        input_shape: Shape of input images (height, width, channels)
        num_classes: Number of output classes
        l2_reg: L2 regularization factor for conv and hidden dense kernels

    Returns:
        Keras Model
    """
    def _conv_bn(tensor, filters):
        # 3x3 same-padded ReLU convolution followed by batch normalisation.
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same',
                               kernel_initializer='he_normal',
                               kernel_regularizer=regularizers.l2(l2_reg))(tensor)
        return layers.BatchNormalization()(tensor)

    def _conv_stage(tensor, filters, dropout_rate):
        # Two conv+BN layers, then 2x2 down-sampling and dropout.
        tensor = _conv_bn(tensor, filters)
        tensor = _conv_bn(tensor, filters)
        tensor = layers.MaxPooling2D((2, 2))(tensor)
        return layers.Dropout(dropout_rate)(tensor)

    def _dense_bn(tensor, units):
        # Regularised ReLU dense layer with batch norm and heavy dropout.
        tensor = layers.Dense(units, activation='relu',
                              kernel_initializer='he_normal',
                              kernel_regularizer=regularizers.l2(l2_reg))(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Dropout(0.5)(tensor)

    inputs = layers.Input(shape=input_shape)

    features = _conv_stage(inputs, 32, 0.25)
    features = spatial_attention(attention_module(features))

    features = _conv_stage(features, 64, 0.3)
    features = channel_attention(attention_module(features))

    features = _conv_stage(features, 128, 0.4)

    # Flatten spatial dimensions before the dense head
    head = layers.Flatten()(features)
    head = _dense_bn(head, 512)
    head = _dense_bn(head, 256)

    outputs = layers.Dense(num_classes, activation='softmax',
                           kernel_initializer='glorot_uniform')(head)

    return models.Model(inputs, outputs)

# Architecture summary (one print, same output lines).
print("\n".join([
    "✓ Improved MACNN architecture loaded",
    "  - Double conv layers per block for better feature extraction",
    "  - He normal initialization for ReLU activations",
    "  - L2 regularization to prevent overfitting",
    "  - Optimized dropout rates",
    "  - Additional dense layer for better classification",
]))

In [None]:
from pathlib import Path
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import LabelBinarizer

DATA_ROOT = Path.cwd() / "Datasets" / "RealAIGI"   
TARGET_SIZE = (256, 256)                                   
COLOR_MODE = "rgb"                                       

def load_local_dataset(data_root: Path,
                       target_size=(256, 256),
                       color_mode="rgb"):
    if not data_root.exists():
        raise FileNotFoundError(f"{data_root} does not exist")

    images = []
    labels = []

    for class_dir in sorted(data_root.iterdir()):
        if not class_dir.is_dir():
            continue
        for image_path in class_dir.glob("*"):
            if image_path.suffix.lower() not in {".jpg", ".jpeg", ".png", ".bmp", ".gif"}:
                continue  # skip unknown file types

            img = tf.keras.utils.load_img(
                image_path,
                target_size=target_size,
                color_mode=color_mode
            )
            arr = tf.keras.utils.img_to_array(img) / 255.0
            images.append(arr)
            labels.append(class_dir.name)

    if not images:
        raise ValueError(f"No images found under {data_root}")

    image_array = np.stack(images, axis=0)
    lb = LabelBinarizer()
    label_array = lb.fit_transform(labels)
    if label_array.ndim == 1:  # binary case
        label_array = label_array[:, np.newaxis]

    return image_array, label_array, lb.classes_

# Load every image under DATA_ROOT and derive the model's input shape and
# class count from the resulting arrays.
image_list, label_list, class_names = load_local_dataset(
    DATA_ROOT, target_size=TARGET_SIZE, color_mode=COLOR_MODE
)
input_shape = image_list.shape[1:]
num_classes = label_list.shape[1]

print(f"Loaded custom dataset: {len(image_list)} samples, image shape {input_shape}")
print(f"Classes: {class_names}")

In [None]:
# Hold out 10% of the samples as the global test set (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(
    image_list, label_list, test_size=0.1, random_state=42
)

### IID

In [None]:
tuple(len(part) for part in (X_train, X_test, y_train, y_test))

In [None]:
# Partition the training split across 100 simulated clients (client_1 ... client_100).
clients = create_clients(X_train, y_train, initial='client', num_clients=100)

In [None]:
# client_names = ['{}_{}'.format('client', i+1) for i in range(100)]
# s = clients['client_1'][0][1]*0
# for c in client_names:
#     sum = clients[c][0][1]
#     for i in range(1,378):
#         sum = sum + clients[c][i][1]
        
#     s = s + sum/378
# s

In [None]:
# Wrap each client's shard in a shuffled, batched tf.data pipeline.
clients_batched = {name: batch_data(shard) for name, shard in clients.items()}

# The whole held-out set is served as a single batch.
test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

In [None]:
# ============================================================================
# IMPROVED TRAINING HYPERPARAMETERS
# ============================================================================

# Core schedule
initial_lr = 0.01
comms_round = 100  # Increased from 50 to 100 for better convergence
local_epochs = 3   # Multiple local epochs per round (was 1)
batch_size = 32    # Batch size for training

# Loss and metrics passed to every local model's compile()
loss, metrics = 'categorical_crossentropy', ['accuracy']

# Exponential learning-rate decay schedule.
# NOTE(review): the IID training loop in this notebook does not reference
# lr_schedule; it computes its own per-round rate
# (initial_lr * 0.96 ** (round // 10)). Confirm which schedule is intended.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=comms_round * 10,
    decay_rate=0.96,
    staircase=True,
)

# Client sampling
num_clients_per_round = 10  # Number of clients to sample each round
min_clients_per_round = 5   # Minimum clients needed (not used by the visible loop)

print("\n".join([
    "✓ Training hyperparameters configured:",
    f"  - Initial learning rate: {initial_lr}",
    f"  - Communication rounds: {comms_round}",
    f"  - Local epochs per round: {local_epochs}",
    f"  - Batch size: {batch_size}",
    f"  - Clients per round: {num_clients_per_round}",
    "  - Learning rate schedule: ExponentialDecay (decay_rate=0.96)",
]))

In [None]:
# Fresh global model plus per-round history lists for the IID experiment.
global_model = MACNN(input_shape, num_classes)
global_acc_list, global_loss_list = [], []

In [None]:
# ============================================================================
# IMPROVED FEDERATED LEARNING TRAINING LOOP (IID)
# ============================================================================
# Each round:
#   1. sample clients;
#   2. train a fresh copy of the global model on each client's shard;
#   3. weight each result by the client's share of the round's samples and
#      aggregate (FedAvg);
#   4. evaluate the new global model on the test set.

# `debug` is not defined elsewhere in this notebook section. Default it so the
# loop does not fail with NameError, while still honouring a value the user
# set in an earlier cell.
debug = globals().get("debug", False)

# Initialize tracking lists
global_acc_list = []
global_loss_list = []
divergence_list = []
per_round_time = []

# Reset the evaluation-time counter that test_model() accumulates into
cumulative_time = 0.0

print("="*70)
print("STARTING IMPROVED FEDERATED LEARNING - IID SETTING")
print("="*70)
print(f"Total clients: {len(clients_batched)}")
print(f"Clients per round: {num_clients_per_round}")
print(f"Communication rounds: {comms_round}")
print(f"Local epochs: {local_epochs}")
print("="*70)

# Training loop
for comm_round in range(comms_round):
    round_start_time = time.time()

    # Get current global weights
    global_weights = global_model.get_weights()
    scaled_local_weight_list = []

    # Sample clients for this round
    all_client_names = list(clients_batched.keys())
    num_selected = min(num_clients_per_round, len(all_client_names))
    client_names = random.sample(all_client_names, k=num_selected)

    if debug:
        print(f"\nRound {comm_round}: Selected clients: {client_names[:5]}...")

    # FedAvg weights N_k / N must be normalised over the clients that take
    # part in THIS round. Normalising over all clients makes the weights sum
    # to num_selected / total_clients and shrinks the aggregated model.
    round_clients = {name: clients_batched[name] for name in client_names}

    # The learning rate only changes between rounds, so compute it once
    current_lr = initial_lr * (0.96 ** (comm_round // 10))

    # Local training on selected clients
    for client in client_names:
        local_model = MACNN(input_shape, num_classes, l2_reg=1e-4)

        # Fresh optimizer per client, so momentum state is not shared
        local_optimizer = SGD(
            learning_rate=current_lr,
            momentum=0.9,
            nesterov=True
        )
        local_model.compile(
            loss=loss,
            optimizer=local_optimizer,
            metrics=metrics
        )

        # Start from the current global weights
        local_model.set_weights(global_weights)

        local_model.fit(
            clients_batched[client],
            epochs=local_epochs,
            verbose=0
        )

        # Scale by this client's share of the round's data
        scaling_factor = weight_scalling_factor(round_clients, client)
        scaled_local_weight_list.append(
            scale_model_weights(local_model.get_weights(), scaling_factor)
        )

        # Clean up
        del local_model
        tf.keras.backend.clear_session()

    # Aggregate client weights
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    # Compute weight divergence for monitoring
    divergence = compute_weight_divergence(global_weights, average_weights)
    divergence_list.append(divergence)

    # Update global model
    global_model.set_weights(average_weights)

    # Evaluate on the test set. Use distinct names so the module-level
    # X_test array from the train/test split is not overwritten by the
    # tensor yielded from the dataset.
    for (X_eval, Y_eval) in test_batched:
        global_acc, global_loss = test_model(X_eval, Y_eval, global_model, comm_round)
        global_acc_list.append(global_acc)
        global_loss_list.append(global_loss)

    # Track wall-clock time of the whole round (training + evaluation)
    per_round_time.append(time.time() - round_start_time)

    # Print progress every 10 rounds
    if (comm_round + 1) % 10 == 0:
        avg_time = np.mean(per_round_time[-10:])
        print(f"\n[Round {comm_round+1}/{comms_round}] "
              f"Acc: {global_acc:.4f} | Loss: {global_loss:.4f} | "
              f"Divergence: {divergence:.4f} | "
              f"Avg Time: {avg_time:.2f}s")

print("\n" + "="*70)
print("FEDERATED LEARNING COMPLETED - IID SETTING")
print("="*70)
if global_acc_list:
    print(f"Final Accuracy: {global_acc_list[-1]:.4f}")
    print(f"Final Loss: {global_loss_list[-1]:.4f}")
# cumulative_time only covers evaluation inside test_model(); report the
# round wall-clock total separately.
print(f"Total Time: {sum(per_round_time):.2f} seconds")
print(f"Evaluation Time: {cumulative_time:.2f} seconds")
print("="*70)

In [None]:
# ============================================================================
# COMPREHENSIVE VISUALIZATION AND ANALYSIS (IID)
# ============================================================================
# Nine-panel summary of the IID run built from the histories produced by the
# IID training loop: global_acc_list, global_loss_list, divergence_list,
# per_round_time, plus the evaluation-time counter cumulative_time.

import matplotlib.pyplot as plt
import seaborn as sns

if not global_acc_list:
    raise RuntimeError("No IID training history found - run the IID training loop first.")

sns.set_style("whitegrid")


def _decorate(xlabel, ylabel, title, legend=True):
    """Apply the shared labels/title/grid/legend styling to the current subplot."""
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    if legend:
        plt.legend()


acc_rounds = np.arange(len(global_acc_list))
loss_rounds = np.arange(len(global_loss_list))

# Wall-clock time per round covers training + evaluation. cumulative_time
# only covers evaluation inside test_model().
total_round_time = float(np.sum(per_round_time)) if per_round_time else 0.0
mean_round_time = total_round_time / len(per_round_time) if per_round_time else 0.0

fig = plt.figure(figsize=(20, 12))

# 1. Accuracy over communication rounds
ax1 = plt.subplot(3, 3, 1)
plt.plot(acc_rounds, global_acc_list, 'b-', linewidth=2, label='Global Accuracy')
plt.fill_between(acc_rounds, global_acc_list, alpha=0.3)
_decorate('Communication Round', 'Accuracy', 'Global Model Accuracy (IID)')

# 2. Loss over communication rounds
ax2 = plt.subplot(3, 3, 2)
plt.plot(loss_rounds, global_loss_list, 'r-', linewidth=2, label='Global Loss')
plt.fill_between(loss_rounds, global_loss_list, alpha=0.3, color='red')
_decorate('Communication Round', 'Loss', 'Global Model Loss (IID)')

# 3. Weight Divergence
ax3 = plt.subplot(3, 3, 3)
plt.plot(np.arange(len(divergence_list)), divergence_list, 'g-', linewidth=2,
         label='Weight Divergence')
_decorate('Communication Round', 'L2 Divergence', 'Model Update Divergence (IID)')

# 4. Moving average accuracy (window=10), plotted at the last round of each window
ax4 = plt.subplot(3, 3, 4)
window_size = 10
if len(global_acc_list) >= window_size:
    moving_avg = np.convolve(global_acc_list, np.ones(window_size) / window_size, mode='valid')
    plt.plot(np.arange(window_size - 1, len(global_acc_list)), moving_avg, 'purple',
             linewidth=2, label=f'MA({window_size})')
    _decorate('Communication Round', 'Accuracy (MA)', 'Moving Average Accuracy (IID)')

# 5. Accuracy improvement rate (change from round r-1 to r, plotted at r)
ax5 = plt.subplot(3, 3, 5)
if len(global_acc_list) > 1:
    plt.plot(acc_rounds[1:], np.diff(global_acc_list), 'orange', linewidth=2, label='Δ Accuracy')
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    _decorate('Communication Round', 'Accuracy Change', 'Accuracy Improvement Rate (IID)')

# 6. Loss improvement rate (change from round r-1 to r, plotted at r)
ax6 = plt.subplot(3, 3, 6)
if len(global_loss_list) > 1:
    plt.plot(loss_rounds[1:], np.diff(global_loss_list), 'brown', linewidth=2, label='Δ Loss')
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    _decorate('Communication Round', 'Loss Change', 'Loss Improvement Rate (IID)')

# 7. Cumulative wall-clock time per round
ax7 = plt.subplot(3, 3, 7)
if per_round_time:
    plt.plot(np.arange(len(per_round_time)), np.cumsum(per_round_time), 'cyan',
             linewidth=2, label='Cumulative Time')
    _decorate('Communication Round', 'Time (seconds)', 'Cumulative Training Time (IID)')

# 8. Statistics summary table
ax8 = plt.subplot(3, 3, 8)
ax8.axis('off')
stats_text = f"""
TRAINING STATISTICS (IID)

Total Rounds: {len(global_acc_list)}
─────────────────────────
Accuracy:
  Initial:  {global_acc_list[0]:.4f}
  Final:    {global_acc_list[-1]:.4f}
  Max:      {max(global_acc_list):.4f}
  Mean:     {np.mean(global_acc_list):.4f}
  Std:      {np.std(global_acc_list):.4f}
─────────────────────────
Loss:
  Initial:  {global_loss_list[0]:.4f}
  Final:    {global_loss_list[-1]:.4f}
  Min:      {min(global_loss_list):.4f}
  Mean:     {np.mean(global_loss_list):.4f}
  Std:      {np.std(global_loss_list):.4f}
─────────────────────────
Training Time:
  Total:    {total_round_time:.2f}s
  Per Round:{mean_round_time:.2f}s
  Eval:     {cumulative_time:.2f}s
─────────────────────────
Model Stability:
  Avg Divergence: {np.mean(divergence_list) if divergence_list else 0.0:.4f}
  Max Divergence: {max(divergence_list) if divergence_list else 0.0:.4f}
"""
ax8.text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
         verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

# 9. Accuracy distribution
ax9 = plt.subplot(3, 3, 9)
plt.hist(global_acc_list, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
plt.axvline(np.mean(global_acc_list), color='red', linestyle='--', linewidth=2, label='Mean')
plt.axvline(np.median(global_acc_list), color='green', linestyle='--', linewidth=2, label='Median')
_decorate('Accuracy', 'Frequency', 'Accuracy Distribution (IID)')

plt.tight_layout()
plt.savefig('fl_iid_comprehensive_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("IID SETTING - COMPREHENSIVE ANALYSIS")
print("="*70)
print("✓ Visualization saved as 'fl_iid_comprehensive_analysis.png'")
print(f"✓ Total communication rounds: {len(global_acc_list)}")
print(f"✓ Final accuracy: {global_acc_list[-1]:.4f} (improvement: {global_acc_list[-1]-global_acc_list[0]:.4f})")
print(f"✓ Final loss: {global_loss_list[-1]:.4f}")
print("="*70)

In [None]:
# Persist the per-round IID history (one row per communication round).
# NOTE(review): the file name says MNIST, but the data loaded in this notebook
# comes from Datasets/RealAIGI. Confirm the name is intended.
iid_df = pd.DataFrame(
    {'global_acc_list': global_acc_list, 'global_loss_list': global_loss_list}
)
iid_df.to_csv('MNIST_IID.csv', index=False)

### Non-IID

In [None]:
def create_clients(image_list, label_list, num_clients=100, initial='clients'):
    """
    Create *non-IID* client data shards.

    Samples are sorted by class and then cut into equal, contiguous shards, so
    each client holds only one or a few classes.

    Returns:
        A dictionary mapping client names to data shards (lists of
        ``(image, label)`` tuples).

    Args:
        image_list: Sequence/array of training images.
        label_list: Sequence/array of binarized labels, one row per image.
        num_clients: Number of federated members (clients).
        initial: The clients' name prefix, e.g. 'client' -> 'client_1'.

    Raises:
        ValueError: If ``num_clients`` is not positive or exceeds the number of
            samples (which would otherwise produce empty shards).

    Note:
        The trailing ``len(data) % num_clients`` samples are dropped so that
        every shard has the same size.
    """
    if num_clients <= 0:
        raise ValueError(f"num_clients must be positive, got {num_clients}")

    # Create a list of client names
    client_names = [f'{initial}_{i + 1}' for i in range(num_clients)]

    pairs = list(zip(image_list, label_list))
    if len(pairs) < num_clients:
        raise ValueError(
            f"Cannot split {len(pairs)} samples across {num_clients} clients"
        )

    # Class id per sample: argmax of one-hot rows, or the value itself for
    # single-column binary labels (where argmax would always be 0).
    labels = np.asarray(label_list)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        class_ids = np.argmax(labels, axis=-1)
    else:
        class_ids = labels.reshape(len(labels), -1)[:, 0]

    # Sort by class for the non-IID split; sorted() is stable, so samples keep
    # their relative order within a class.
    order = sorted(range(len(pairs)), key=lambda i: class_ids[i])
    data = [pairs[i] for i in order]

    # Shard data into equal contiguous chunks, one per client
    size = len(data) // num_clients
    return {
        name: data[k * size:(k + 1) * size]
        for k, name in enumerate(client_names)
    }

In [None]:
tuple(len(part) for part in (X_train, X_test, y_train, y_test))

In [None]:
# Re-partition the training split across 100 clients with the non-IID create_clients.
clients = create_clients(X_train, y_train, initial='client', num_clients=100)

In [None]:
# Wrap each client's shard in a shuffled, batched tf.data pipeline.
clients_batched = {name: batch_data(shard) for name, shard in clients.items()}

# The whole held-out set is served as a single batch.
test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

In [None]:
# ============================================================================
# TRAINING HYPERPARAMETERS FOR NON-IID
# ============================================================================

# Toggle to quickly sanity-check the training loop without long runtimes
quick_debug_mode = True  # Set to False for full-length training

# Full-training defaults
initial_lr_noniid = 0.008          # Slightly lower LR for non-IID stability
comms_round_noniid = 150           # More rounds needed for non-IID convergence
local_epochs_noniid = 3
batch_size_noniid = 32
num_clients_per_round_noniid = 10

# Runtime safeguards; None disables each one
max_batches_per_client_noniid = None  # cap on batches per client per round
max_total_minutes_noniid = None       # wall-clock budget in minutes
target_accuracy_noniid = None         # early-stop accuracy threshold

if quick_debug_mode:
    # Shrink the workload for quick experimentation / debugging sessions
    comms_round_noniid = min(comms_round_noniid, 20)
    local_epochs_noniid = min(local_epochs_noniid, 2)
    num_clients_per_round_noniid = min(num_clients_per_round_noniid, 5)
    max_batches_per_client_noniid = 5
    max_total_minutes_noniid = 5  # minutes
    target_accuracy_noniid = 0.65

loss = 'categorical_crossentropy'
metrics = ['accuracy']

print("\n".join([
    "✓ Non-IID Training hyperparameters configured:",
    f"  - Quick debug mode: {'ENABLED' if quick_debug_mode else 'disabled'}",
    f"  - Initial learning rate: {initial_lr_noniid}",
    f"  - Communication rounds (max): {comms_round_noniid}",
    f"  - Local epochs per round: {local_epochs_noniid}",
    f"  - Batch size: {batch_size_noniid}",
    f"  - Clients per round: {num_clients_per_round_noniid}",
    f"  - Max batches per client: {'unlimited' if max_batches_per_client_noniid is None else max_batches_per_client_noniid}",
    f"  - Time budget (minutes): {'unlimited' if max_total_minutes_noniid is None else max_total_minutes_noniid}",
    f"  - Early-stop accuracy: {'disabled' if target_accuracy_noniid is None else target_accuracy_noniid}",
]))

In [None]:
# Initialize a fresh global model for the Non-IID experiment.
# NOTE: global_acc_list / global_loss_list are intentionally NOT reset here.
# The Non-IID run records its metrics in the *_noniid lists, while the
# IID-vs-Non-IID comparison cell later in this notebook still reads these two
# lists as the IID history; clearing them would leave that comparison empty.
input_shape = (32, 32, 3)  # CIFAR-10 images: height, width, channels
num_classes = 10
global_model = MACNN(input_shape, num_classes)

In [None]:
# ============================================================================
# IMPROVED FEDERATED LEARNING TRAINING LOOP (NON-IID)
# ============================================================================
# Each communication round: sample a subset of clients, train a fresh local
# model (initialised from the current global weights) on each, scale every
# client's weights by its share of the round's effective samples, combine the
# scaled weights into the new global model, and evaluate it on the test set.
# MACNN, scale_model_weights, sum_scaled_weights, compute_weight_divergence,
# test_model, clients, clients_batched and test_batched come from earlier cells.

# Initialize tracking lists for Non-IID
global_acc_list_noniid = []   # one entry per evaluated test batch
global_loss_list_noniid = []  # one entry per evaluated test batch
divergence_list_noniid = []   # one entry per aggregated round
per_round_time_noniid = []    # wall-clock seconds per fully completed round

# Reset cumulative time (reported in the summary as time tracked via test_model)
cumulative_time = 0.0

# The client pool does not change between rounds, so enumerate it once.
all_client_names = list(clients_batched.keys())
num_available_clients = len(all_client_names)
num_selected = min(num_clients_per_round_noniid, num_available_clients)

print("="*70)
print("STARTING IMPROVED FEDERATED LEARNING - NON-IID SETTING")
print("="*70)
print(f"Total clients: {len(clients_batched)}")
print(f"Clients per round: {num_clients_per_round_noniid}")
print(f"Communication rounds (max): {comms_round_noniid}")
print(f"Local epochs: {local_epochs_noniid}")
print("⚠️  NON-IID: Clients have non-uniformly distributed data")
print("="*70)

overall_start_time = time.time()
stop_reason = None
interrupted = False
failed = False  # set when an unexpected exception escapes the loop

try:
    for comm_round in range(comms_round_noniid):
        if num_selected == 0:
            stop_reason = "No clients available for selection"
            break

        round_start_time = time.time()

        # Get current global weights
        global_weights = global_model.get_weights()
        scaled_local_weight_list = []

        # Sample clients for this round
        client_names = random.sample(all_client_names, k=num_selected)

        # NOTE: `debug` is not defined in this cell; it must come from an earlier one.
        if debug:
            print(f"\nRound {comm_round}: Selected clients: {client_names[:5]}...")

        # Effective sample count per selected client. When batches are capped a
        # client contributes at most max_batches * batch_size samples
        # (assumes clients_batched was batched with batch_size_noniid -- TODO confirm).
        if max_batches_per_client_noniid is not None:
            sample_cap = max_batches_per_client_noniid * batch_size_noniid
            client_effective_sizes = {
                name: min(len(clients[name]), sample_cap) for name in client_names
            }
        else:
            client_effective_sizes = {name: len(clients[name]) for name in client_names}

        # max(..., 1) guards the scaling division below against an all-empty round.
        round_total_samples = max(sum(client_effective_sizes.values()), 1)

        # One learning rate for every client in this round: multiply by 0.98
        # once every 15 rounds (slow decay for non-IID).
        current_lr = initial_lr_noniid * (0.98 ** (comm_round // 15))

        # Local training on selected clients
        for client in client_names:
            local_model = MACNN(input_shape, num_classes, l2_reg=1e-4)

            # A new optimizer per local model, so no momentum state is shared between clients.
            local_optimizer = SGD(
                learning_rate=current_lr,
                momentum=0.9,
                nesterov=True
            )

            local_model.compile(
                loss=loss,
                optimizer=local_optimizer,
                metrics=metrics
            )

            # Start local training from the current global weights
            local_model.set_weights(global_weights)

            # Prepare dataset (optionally capped for faster iterations)
            client_dataset = clients_batched[client]
            if max_batches_per_client_noniid is not None:
                client_dataset = client_dataset.take(max_batches_per_client_noniid)
            client_dataset = client_dataset.prefetch(tf.data.AUTOTUNE)

            # Train for multiple local epochs
            local_model.fit(
                client_dataset,
                epochs=local_epochs_noniid,
                verbose=0
            )

            # Weight this client by its share of the round's effective samples
            scaling_factor = client_effective_sizes[client] / round_total_samples

            # Scale model weights and collect
            scaled_weights = scale_model_weights(
                local_model.get_weights(),
                scaling_factor
            )
            scaled_local_weight_list.append(scaled_weights)

            # Release the local model before building the next one
            del local_model
            tf.keras.backend.clear_session()

        if not scaled_local_weight_list:
            stop_reason = "No client updates were gathered in this round"
            break

        # Combine the scaled client weights into the new global weights
        average_weights = sum_scaled_weights(scaled_local_weight_list)

        # Compute weight divergence for monitoring
        divergence = compute_weight_divergence(global_weights, average_weights)
        divergence_list_noniid.append(divergence)

        # Update global model
        global_model.set_weights(average_weights)

        # Evaluate on test set (one metric pair per test batch)
        global_acc = None
        global_loss = None
        for (X_test, Y_test) in test_batched:
            global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round)
            global_acc_list_noniid.append(global_acc)
            global_loss_list_noniid.append(global_loss)

        # Track round time
        round_time = time.time() - round_start_time
        per_round_time_noniid.append(round_time)

        # Frequent progress updates for long runs
        if (comm_round + 1) % 5 == 0 or comm_round == 0:
            display_acc = f"{global_acc:.4f}" if global_acc is not None else "N/A"
            display_loss = f"{global_loss:.4f}" if global_loss is not None else "N/A"
            display_div = f"{divergence:.4f}"
            print(
                f"[Round {comm_round+1}/{comms_round_noniid}] "
                f"Acc: {display_acc} | Loss: {display_loss} | "
                f"Divergence: {display_div} | Time: {round_time:.2f}s"
            )

        # Early stopping conditions
        if target_accuracy_noniid is not None and global_acc is not None and global_acc >= target_accuracy_noniid:
            stop_reason = (
                f"Target accuracy reached ({global_acc:.4f} ≥ {target_accuracy_noniid:.4f})"
            )
            break

        elapsed_minutes = (time.time() - overall_start_time) / 60.0
        if max_total_minutes_noniid is not None and elapsed_minutes >= max_total_minutes_noniid:
            stop_reason = (
                f"Time budget exceeded ({elapsed_minutes:.1f} min ≥ {max_total_minutes_noniid} min)"
            )
            break
except KeyboardInterrupt:
    interrupted = True
    stop_reason = "Manual interruption detected (KeyboardInterrupt)"
except Exception as exc:
    # Record the failure so the summary below does not report success, then
    # re-raise so the notebook still shows the original traceback.
    failed = True
    stop_reason = f"Unexpected error: {exc!r}"
    raise
finally:
    total_elapsed_minutes = (time.time() - overall_start_time) / 60.0
    print("\n" + "="*70)
    print("FEDERATED LEARNING SUMMARY - NON-IID SETTING")
    print("="*70)

    if interrupted:
        print("⚠️  Training interrupted by user; partial progress preserved.")
    elif failed:
        print(f"❌ Training aborted: {stop_reason}")
    elif stop_reason is not None:
        print(f"⚠️  Training stopped early: {stop_reason}")
    else:
        print("✅ Training completed all scheduled rounds.")

    # Count rounds by their timing entries: the metric lists receive one entry
    # per test batch, which differs from the round count if test_batched
    # yields more than one batch.
    rounds_completed = len(per_round_time_noniid)
    print(f"Rounds completed: {rounds_completed}")
    print(f"Total elapsed time: {total_elapsed_minutes:.2f} minutes")

    if global_acc_list_noniid:
        print(f"Final Accuracy: {global_acc_list_noniid[-1]:.4f}")
        print(f"Final Loss: {global_loss_list_noniid[-1]:.4f}")
    else:
        print("No evaluation metrics were recorded before stopping.")

    print(f"Total time tracked via test_model: {cumulative_time:.2f} seconds")
    print("="*70)

In [None]:
# ============================================================================
# COMPREHENSIVE VISUALIZATION AND ANALYSIS (NON-IID)
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

fig = plt.figure(figsize=(20, 12))

# Panels 1-3: per-round series for accuracy, loss and weight divergence.
# Each spec: (subplot slot, series, line fmt, legend label, y-label, title,
#             fill_between kwargs -- None means "no shaded area").
_series_panels = (
    (1, global_acc_list_noniid, 'b-', 'Global Accuracy', 'Accuracy',
     'Global Model Accuracy (Non-IID)', {}),
    (2, global_loss_list_noniid, 'r-', 'Global Loss', 'Loss',
     'Global Model Loss (Non-IID)', {'color': 'red'}),
    (3, divergence_list_noniid, 'g-', 'Weight Divergence', 'L2 Divergence',
     'Model Update Divergence (Non-IID)', None),
)

_panel_axes = []
for _slot, _series, _fmt, _legend, _ylab, _heading, _fill_kwargs in _series_panels:
    _ax = plt.subplot(3, 3, _slot)
    _xs = range(len(_series))
    _ax.plot(_xs, _series, _fmt, linewidth=2, label=_legend)
    if _fill_kwargs is not None:
        # An empty kwargs dict keeps matplotlib's automatic fill colour.
        _ax.fill_between(_xs, _series, alpha=0.3, **_fill_kwargs)
    _ax.set_xlabel('Communication Round', fontsize=12)
    _ax.set_ylabel(_ylab, fontsize=12)
    _ax.set_title(_heading, fontsize=14, fontweight='bold')
    _ax.grid(True, alpha=0.3)
    _ax.legend()
    _panel_axes.append(_ax)

ax1, ax2, ax3 = _panel_axes

# 4. Moving average accuracy (window=10)
ax4 = plt.subplot(3, 3, 4)
window_size = 10
if len(global_acc_list_noniid) >= window_size:
    moving_avg = np.convolve(global_acc_list_noniid, np.ones(window_size)/window_size, mode='valid')
    # mode='valid' yields len(acc) - window_size + 1 averages; entry i covers
    # rounds i .. i+window_size-1, so plot it at the window's last round to
    # stay aligned with the raw accuracy panel instead of shifting left.
    ma_rounds = range(window_size - 1, len(global_acc_list_noniid))
    plt.plot(ma_rounds, moving_avg, 'purple', linewidth=2, label=f'MA({window_size})')
    plt.xlabel('Communication Round', fontsize=12)
    plt.ylabel('Accuracy (MA)', fontsize=12)
    plt.title('Moving Average Accuracy (Non-IID)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend()

# 5. Accuracy improvement rate
ax5 = plt.subplot(3, 3, 5)
if len(global_acc_list_noniid) > 1:
    acc_diff = np.diff(global_acc_list_noniid)
    # acc_diff[i] = acc[i+1] - acc[i] is the change produced by round i+1,
    # so the x-axis starts at round 1.
    plt.plot(range(1, len(global_acc_list_noniid)), acc_diff, 'orange', linewidth=2, label='Δ Accuracy')
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    plt.xlabel('Communication Round', fontsize=12)
    plt.ylabel('Accuracy Change', fontsize=12)
    plt.title('Accuracy Improvement Rate (Non-IID)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend()

# 6. Loss improvement rate
ax6 = plt.subplot(3, 3, 6)
if len(global_loss_list_noniid) > 1:
    loss_diff = np.diff(global_loss_list_noniid)
    # Same alignment as the accuracy deltas: loss_diff[i] belongs to round i+1.
    plt.plot(range(1, len(global_loss_list_noniid)), loss_diff, 'brown', linewidth=2, label='Δ Loss')
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    plt.xlabel('Communication Round', fontsize=12)
    plt.ylabel('Loss Change', fontsize=12)
    plt.title('Loss Improvement Rate (Non-IID)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend()

# 7. Cumulative time per round
ax7 = plt.subplot(3, 3, 7)
if len(per_round_time_noniid) > 0:
    plt.plot(range(len(per_round_time_noniid)), np.cumsum(per_round_time_noniid), 'cyan', linewidth=2, label='Cumulative Time')
    plt.xlabel('Communication Round', fontsize=12)
    plt.ylabel('Time (seconds)', fontsize=12)
    plt.title('Cumulative Training Time (Non-IID)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend()

# 8. Statistics summary table
ax8 = plt.subplot(3, 3, 8)
ax8.axis('off')


def _fmt4(values, reducer):
    """Return reducer(values) formatted to 4 decimals, or 'N/A' when values is empty."""
    return f"{reducer(values):.4f}" if len(values) > 0 else "N/A"


def _first(values):
    """Return the first element of a sequence."""
    return values[0]


def _last(values):
    """Return the last element of a sequence."""
    return values[-1]


# Training can stop before any evaluation (the training cell reports
# "No evaluation metrics were recorded"), so every statistic degrades to
# 'N/A' instead of raising IndexError / ZeroDivisionError / ValueError.
_n_evaluated = len(global_acc_list_noniid)
_per_round_text = f"{cumulative_time / _n_evaluated:.2f}s" if _n_evaluated else "N/A"

stats_text = f"""
TRAINING STATISTICS (NON-IID)

Total Rounds: {_n_evaluated}
─────────────────────────
Accuracy:
  Initial:  {_fmt4(global_acc_list_noniid, _first)}
  Final:    {_fmt4(global_acc_list_noniid, _last)}
  Max:      {_fmt4(global_acc_list_noniid, max)}
  Mean:     {_fmt4(global_acc_list_noniid, np.mean)}
  Std:      {_fmt4(global_acc_list_noniid, np.std)}
─────────────────────────
Loss:
  Initial:  {_fmt4(global_loss_list_noniid, _first)}
  Final:    {_fmt4(global_loss_list_noniid, _last)}
  Min:      {_fmt4(global_loss_list_noniid, min)}
  Mean:     {_fmt4(global_loss_list_noniid, np.mean)}
  Std:      {_fmt4(global_loss_list_noniid, np.std)}
─────────────────────────
Training Time:
  Total:    {cumulative_time:.2f}s
  Per Round:{_per_round_text}
─────────────────────────
Model Stability:
  Avg Divergence: {_fmt4(divergence_list_noniid, np.mean)}
  Max Divergence: {_fmt4(divergence_list_noniid, max)}
"""
ax8.text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
         verticalalignment='center', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

# 9. Accuracy distribution: histogram of per-round accuracies with
#    dashed reference lines at the mean and the median.
ax9 = plt.subplot(3, 3, 9)
ax9.hist(global_acc_list_noniid, bins=30, edgecolor='black', alpha=0.7, color='lightcoral')
for _centre_fn, _centre_colour, _centre_name in (
    (np.mean, 'red', 'Mean'),
    (np.median, 'green', 'Median'),
):
    ax9.axvline(_centre_fn(global_acc_list_noniid), color=_centre_colour,
                linestyle='--', linewidth=2, label=_centre_name)
ax9.set_xlabel('Accuracy', fontsize=12)
ax9.set_ylabel('Frequency', fontsize=12)
ax9.set_title('Accuracy Distribution (Non-IID)', fontsize=14, fontweight='bold')
ax9.legend()
ax9.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('fl_noniid_comprehensive_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("NON-IID SETTING - COMPREHENSIVE ANALYSIS")
print("="*70)
print("✓ Visualization saved as 'fl_noniid_comprehensive_analysis.png'")
print(f"✓ Total communication rounds: {len(global_acc_list_noniid)}")
# Training may stop before any evaluation; indexing [0] / [-1] would raise then.
if global_acc_list_noniid:
    print(f"✓ Final accuracy: {global_acc_list_noniid[-1]:.4f} (improvement: {global_acc_list_noniid[-1]-global_acc_list_noniid[0]:.4f})")
    print(f"✓ Final loss: {global_loss_list_noniid[-1]:.4f}")
else:
    print("⚠️  No evaluation metrics recorded; final accuracy/loss unavailable.")
print("="*70)

In [None]:
# Persist the Non-IID metric history. This previously exported
# global_acc_list / global_loss_list, which are not the Non-IID results.
# Column names are kept as-is so the CSV layout is unchanged. float() turns
# NumPy / tensor scalars (in case test_model returns them) into plain numbers.
noniid_df = pd.DataFrame(
    [
        (float(acc), float(loss_value))
        for acc, loss_value in zip(global_acc_list_noniid, global_loss_list_noniid)
    ],
    columns=['global_acc_list', 'global_loss_list'],
)
noniid_df.to_csv('CIFAR-10_Non-IID.csv', index=False)

In [None]:
# ============================================================================
# COMPARISON: IID vs NON-IID
# ============================================================================

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Top row: IID (blue) vs Non-IID (red) per-round curves.
# Each spec: (IID series, Non-IID series, y-axis label, panel title).
_top_row_specs = (
    (global_acc_list, global_acc_list_noniid, 'Accuracy', 'Accuracy: IID vs Non-IID'),
    (global_loss_list, global_loss_list_noniid, 'Loss', 'Loss: IID vs Non-IID'),
    (divergence_list, divergence_list_noniid, 'Weight Divergence', 'Model Stability: IID vs Non-IID'),
)

for _column, (_iid_series, _noniid_series, _y_caption, _panel_title) in enumerate(_top_row_specs):
    _panel = axes[0, _column]
    for _series, _style, _tag in ((_iid_series, 'b-', 'IID'), (_noniid_series, 'r-', 'Non-IID')):
        _panel.plot(range(len(_series)), _series, _style, linewidth=2, label=_tag, alpha=0.8)
    _panel.set_xlabel('Communication Round', fontsize=11)
    _panel.set_ylabel(_y_caption, fontsize=11)
    _panel.set_title(_panel_title, fontsize=13, fontweight='bold')
    _panel.legend()
    _panel.grid(True, alpha=0.3)

# 4. Accuracy Distribution Comparison: overlaid histograms of per-round accuracy.
_dist_panel = axes[1, 0]
for _accuracies, _setting, _bar_colour in (
    (global_acc_list, 'IID', 'blue'),
    (global_acc_list_noniid, 'Non-IID', 'red'),
):
    _dist_panel.hist(_accuracies, bins=20, alpha=0.6, label=_setting,
                     edgecolor='black', color=_bar_colour)
_dist_panel.set_xlabel('Accuracy', fontsize=11)
_dist_panel.set_ylabel('Frequency', fontsize=11)
_dist_panel.set_title('Accuracy Distribution Comparison', fontsize=13, fontweight='bold')
_dist_panel.legend()
_dist_panel.grid(True, alpha=0.3)

# 5. Box Plot Comparison of the two accuracy histories.
# NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels`;
# confirm against the installed version before switching.
_box_panel = axes[1, 1]
_box_panel.boxplot(
    [global_acc_list, global_acc_list_noniid],
    labels=['IID', 'Non-IID'],
    patch_artist=True,
    boxprops={'facecolor': 'lightblue', 'alpha': 0.7},
    medianprops={'color': 'red', 'linewidth': 2},
)
_box_panel.set_ylabel('Accuracy', fontsize=11)
_box_panel.set_title('Accuracy Box Plot Comparison', fontsize=13, fontweight='bold')
_box_panel.grid(True, alpha=0.3)

# 6. Summary Statistics Table
axes[1, 2].axis('off')


def _summary_fmt(values, reducer):
    """Return reducer(values) formatted to 4 decimals, or 'N/A' when values is empty."""
    return f"{reducer(values):.4f}" if len(values) > 0 else "N/A"


def _last_value(values):
    """Return the last element of a sequence."""
    return values[-1]


def _final_gap(iid_values, noniid_values):
    """Absolute difference of the two final entries, or 'N/A' if either history is empty."""
    if len(iid_values) == 0 or len(noniid_values) == 0:
        return "N/A"
    return f"{abs(iid_values[-1] - noniid_values[-1]):.4f}"


# Either run may have produced no evaluations (e.g. stopped early), so each
# statistic falls back to 'N/A' instead of raising IndexError / ValueError.
comparison_text = f"""
COMPARATIVE ANALYSIS
════════════════════════════════════

IID Setting:
  Final Acc:     {_summary_fmt(global_acc_list, _last_value)}
  Max Acc:       {_summary_fmt(global_acc_list, max)}
  Mean Acc:      {_summary_fmt(global_acc_list, np.mean)}
  Std Acc:       {_summary_fmt(global_acc_list, np.std)}
  Final Loss:    {_summary_fmt(global_loss_list, _last_value)}
  
Non-IID Setting:
  Final Acc:     {_summary_fmt(global_acc_list_noniid, _last_value)}
  Max Acc:       {_summary_fmt(global_acc_list_noniid, max)}
  Mean Acc:      {_summary_fmt(global_acc_list_noniid, np.mean)}
  Std Acc:       {_summary_fmt(global_acc_list_noniid, np.std)}
  Final Loss:    {_summary_fmt(global_loss_list_noniid, _last_value)}

Performance Gap:
  Acc Difference: {_final_gap(global_acc_list, global_acc_list_noniid)}
  Loss Difference: {_final_gap(global_loss_list, global_loss_list_noniid)}

Convergence:
  IID Rounds:    {len(global_acc_list)}
  Non-IID Rounds: {len(global_acc_list_noniid)}
"""
axes[1, 2].text(0.1, 0.5, comparison_text, fontsize=10, family='monospace',
                verticalalignment='center', 
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))

# Lay out, save and display the comparison figure, then print a short banner.
plt.tight_layout()
_comparison_png = 'fl_iid_vs_noniid_comparison.png'
plt.savefig(_comparison_png, dpi=300, bbox_inches='tight')
plt.show()

_banner_rule = "=" * 70
for _banner_line in (
    "\n" + _banner_rule,
    "COMPREHENSIVE COMPARISON: IID vs NON-IID",
    _banner_rule,
    f"✓ Comparison visualization saved as '{_comparison_png}'",
    _banner_rule,
):
    print(_banner_line)