## 1. Environment Setup and Data Preparation

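The first step simulates a realistic federated environment. We use a **pathological Non-IID split** of the MNIST dataset: the training set is sorted by label, cut into 200 shards of 300 samples, and each of the 100 clients receives two shards. A client therefore typically holds samples from only **2 class labels** (a shard that straddles a class boundary can add one more). This creates the data heterogeneity needed to test the robustness of the FedProx algorithm.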



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import copy
import matplotlib.pyplot as plt
from scipy.stats import wilcoxon

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

def get_mnist_non_iid(num_clients=100):
    """
    Partitions MNIST into Non-IID clients (100 by default).
    Each client receives 2 label-sorted shards, so it typically sees only 2 classes.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    
    # Sort data by labels to create Non-IID shards
    indices = np.arange(len(train_dataset))
    labels = train_dataset.targets.numpy()
    indices_sorted = indices[np.argsort(labels)]
    
    # Divide into 2 * num_clients shards (200 by default; each client gets 2 shards)
    shards = np.array_split(indices_sorted, 2 * num_clients)
    np.random.shuffle(shards)
    
    client_indices = {i: np.concatenate((shards[2*i], shards[2*i+1]), axis=0) 
                      for i in range(num_clients)}
    
    return train_dataset, test_dataset, client_indices

# Initialize Data
train_data, test_data, client_map = get_mnist_non_iid()
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False)
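
To check how heterogeneous the split really is, it can help to count the distinct labels each client holds. The sketch below repeats the same shard logic on synthetic, label-sorted data so it runs without downloading MNIST. `shard_partition` and `classes_per_client` are hypothetical helper names, and the 10 perfectly balanced classes are an assumption (real MNIST classes are slightly unbalanced).

```python
import numpy as np

def shard_partition(labels, num_clients, shards_per_client=2, seed=42):
    """Sort indices by label, cut them into shards, and deal shards to clients."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")
    shards = np.array_split(order, num_clients * shards_per_client)
    shard_ids = rng.permutation(len(shards))
    partition = {}
    for cid in range(num_clients):
        mine = shard_ids[cid * shards_per_client:(cid + 1) * shards_per_client]
        partition[cid] = np.concatenate([shards[s] for s in mine])
    return partition

def classes_per_client(labels, client_indices):
    """Return {client_id: sorted list of distinct labels held by that client}."""
    return {cid: sorted(np.unique(labels[idx]).tolist())
            for cid, idx in client_indices.items()}

# Synthetic stand-in for MNIST targets: 10 balanced classes, 600 samples each (assumption).
labels = np.repeat(np.arange(10), 600)
parts = shard_partition(labels, num_clients=10)
summary = classes_per_client(labels, parts)
for cid in range(3):
    print(cid, summary[cid], len(parts[cid]))
```

On the real data, the same check would be `classes_per_client(train_data.targets.numpy(), client_map)`.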


## 2. Model Architecture

We implement a **Multi-Layer Perceptron (MLP)** with two hidden layers (200 and 100 units). This lightweight model is a common choice for benchmarking FL algorithms on MNIST and trains quickly enough for your one-week experiment.



In [None]:

class FederatedMLP(nn.Module):
    def __init__(self):
        super(FederatedMLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
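
As a rough illustration of how lightweight this model is, the sketch below counts its parameters from the layer sizes alone. `count_mlp_parameters` is a hypothetical helper, and the upload size assumes float32 weights with no compression.

```python
# Layer sizes taken from FederatedMLP: 784 -> 200 -> 100 -> 10
layer_sizes = [28 * 28, 200, 100, 10]

def count_mlp_parameters(sizes):
    """Weights (in * out) plus biases (out) for each fully connected layer."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

total = count_mlp_parameters(layer_sizes)
print(total)  # 178110
# Assumption: 4 bytes per parameter (float32), no compression.
print(f"{total * 4 / 1e6:.2f} MB per model upload")
```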


## 3. FedProx Local Training & Global Evaluation

The **FedProx** baseline introduces a **proximal term** ($\frac{\mu}{2}\lVert w - w^t \rVert^2$, where $w$ are the local weights and $w^t$ the global weights at round $t$) to the local loss function. This regularization term penalizes local updates that deviate significantly from the global model, stabilizing training on heterogeneous data.


In [None]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(loader.dataset)

def local_train_fedprox(model, global_model, train_loader, mu=0.01, epochs=1):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    for _ in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            
            # Standard Cross Entropy Loss
            ce_loss = F.cross_entropy(output, target)
            
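            # FedProx Proximal Term: squared L2 distance to the frozen global weights.
            # detach() keeps gradients from accumulating on global_model.
            prox_term = 0.0
            for p, g_p in zip(model.parameters(), global_model.parameters()):
                prox_term += (p - g_p.detach()).pow(2).sum()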
            
            loss = ce_loss + (mu / 2) * prox_term
            loss.backward()
            optimizer.step()
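
The effect of the proximal term on each SGD step is easiest to see through its gradient, which is `mu * (w - w_global)`: a pull back toward the global weights that grows with the distance from them. The NumPy sketch below checks this analytic gradient against finite differences on random vectors. `prox_penalty` and `prox_grad` are hypothetical names, and the vectors are toy data rather than model weights.

```python
import numpy as np

def prox_penalty(w, w_global, mu):
    """FedProx penalty (mu / 2) * ||w - w_global||^2."""
    return 0.5 * mu * np.sum((w - w_global) ** 2)

def prox_grad(w, w_global, mu):
    """Analytic gradient of the penalty with respect to w."""
    return mu * (w - w_global)

rng = np.random.default_rng(0)
w_global = rng.normal(size=5)                  # toy "global" weights
w = w_global + rng.normal(scale=0.1, size=5)   # toy "local" weights after some drift
mu = 0.01

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = np.array([
    (prox_penalty(w + eps * e, w_global, mu)
     - prox_penalty(w - eps * e, w_global, mu)) / (2 * eps)
    for e in np.eye(5)
])
grads_match = np.allclose(numeric, prox_grad(w, w_global, mu), atol=1e-8)
print(grads_match)
```

With `mu = 0` the gradient vanishes and the local update reduces to plain SGD on the cross-entropy loss.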



## 4. Federated Orchestration (Baseline)

This loop coordinates the global communication rounds. For now it selects clients **uniformly at random** in each round, so FedProx with random selection serves as the SOTA baseline for comparison.



In [None]:
def run_federated_experiment(strategy_name="FedProx_Baseline", num_rounds=30, clients_per_round=10):
    global_model = FederatedMLP()
    accuracy_history = []
    
    print(f"Starting Experiment: {strategy_name}")
    
    for r in range(num_rounds):
        # 1. CLIENT SELECTION STEP 
        # (This is where you will plug in GA, PSO, or SA later)
        all_ids = list(client_map.keys())
        selected_ids = np.random.choice(all_ids, clients_per_round, replace=False)
        
        local_states = []
        
        # 2. LOCAL UPDATES
        for cid in selected_ids:
            local_model = copy.deepcopy(global_model)
            indices = client_map[cid]
            loader = torch.utils.data.DataLoader(train_data, batch_size=32, 
                                                 sampler=torch.utils.data.SubsetRandomSampler(indices))
            
            local_train_fedprox(local_model, global_model, loader, mu=0.01)
            local_states.append(local_model.state_dict())
            
        # 3. GLOBAL AGGREGATION (plain mean; equals sample-weighted averaging
        #    here because every client holds the same number of samples)
        global_dict = global_model.state_dict()
        for key in global_dict.keys():
            global_dict[key] = torch.stack([local_states[i][key] for i in range(len(local_states))], dim=0).mean(0)
        global_model.load_state_dict(global_dict)
        
        # 4. SERVER EVALUATION
        acc = evaluate(global_model, test_loader)
        accuracy_history.append(acc)
        if (r+1) % 5 == 0:
            print(f"Round {r+1}: Global Accuracy = {acc:.2f}%")
            
    return accuracy_history

# Run Baseline
fedprox_results = run_federated_experiment("FedProx_Baseline")
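
The aggregation step above takes a plain mean of the local states, which is only equivalent to sample-weighted averaging while all clients hold the same amount of data. If client sizes ever differ, one option is to weight each state by its sample count. The sketch below shows the idea on NumPy arrays; `weighted_average` is a hypothetical helper and the two toy "states" are made up for illustration.

```python
import numpy as np

def weighted_average(states, num_samples):
    """Average a list of {name: array} states, weighting each by its sample count."""
    weights = np.asarray(num_samples, dtype=float)
    weights = weights / weights.sum()
    return {
        key: sum(w * state[key] for w, state in zip(weights, states))
        for key in states[0]
    }

# Two toy clients: the second holds three times as much data as the first.
states = [
    {"fc.weight": np.array([0.0, 0.0])},
    {"fc.weight": np.array([4.0, 8.0])},
]
avg = weighted_average(states, num_samples=[100, 300])
print(avg["fc.weight"])  # [3. 6.]
```

In the notebook, the sample counts would come from `len(client_map[cid])` for each selected client.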



## 5. Visualization and Statistical Analysis

Following section V of your guidelines, we provide the **Convergence Curve** and the **Wilcoxon rank-sum test** boilerplate.




In [None]:

def plot_results(results_dict):
    plt.figure(figsize=(10, 6))
    for label, history in results_dict.items():
        plt.plot(history, label=label)
    plt.xlabel("Communication Rounds")
    plt.ylabel("Test Accuracy (%)")
    plt.title("Convergence Curves: SOTA vs Nature-Inspired Heuristics")
    plt.legend()
    plt.grid(True)
    plt.show()

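# Boilerplate for Wilcoxon rank-sum test (Compare Baseline vs your eventual GA/PSO/SA)
# Note: scipy.stats.wilcoxon is the signed-rank test; the rank-sum test is scipy.stats.ranksums.
# from scipy.stats import ranksums
# res = ranksums(fedprox_results, nia_results)
# print(f"P-value: {res.pvalue}")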
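
To make clear what the rank-sum test computes, here is a minimal standard-library sketch using the normal approximation: pool both samples, rank them (ties share an average rank), and compare the rank sum of the first sample with its expected value under the null hypothesis. `rank_sum_test` is a hypothetical name, and the accuracy values are placeholders for illustration only, not experimental results. The test assumes independent samples, for example final accuracies from repeated runs with different seeds.

```python
import math

def rank_sum_test(x, y):
    """Two-sided Wilcoxon rank-sum test using the normal approximation."""
    pooled = sorted((v, i) for i, v in enumerate(list(x) + list(y)))
    ranks = [0.0] * len(pooled)
    pos = 0
    while pos < len(pooled):
        end = pos
        # Tied values share the average of the ranks they occupy.
        while end + 1 < len(pooled) and pooled[end + 1][0] == pooled[pos][0]:
            end += 1
        avg_rank = (pos + end) / 2 + 1
        for k in range(pos, end + 1):
            ranks[pooled[k][1]] = avg_rank
        pos = end + 1
    n1, n2 = len(x), len(y)
    w = sum(ranks[:n1])  # rank sum of the first sample
    mean = n1 * (n1 + n2 + 1) / 2
    std = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z = (w - mean) / std
    p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided normal tail
    return z, p_value

# Placeholder numbers for illustration only (NOT measured results).
toy_baseline = [88.1, 87.4, 88.9, 86.7, 88.3]
toy_candidate = [89.5, 90.2, 89.1, 90.8, 89.9]
z, p = rank_sum_test(toy_baseline, toy_candidate)
print(f"z = {z:.3f}, p = {p:.4f}")
```

In practice `scipy.stats.ranksums` provides this test directly; the hand-written version is only meant to show the mechanics.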





### Integration Guide for NIA Step

To plug in your **Genetic Algorithm (GA)**, **Particle Swarm Optimization (PSO)**, or **Simulated Annealing (SA)**, you only need to replace the "CLIENT SELECTION STEP" in Section 4.

1. **Fitness Function:** Use the global accuracy (or a proxy like validation loss) as the fitness value for a particular subset of clients.
2. **Search Space:** A binary string of length 100 (where '1' means a client is selected for that round).
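
The two points above could be wired together roughly as follows. This is a sketch under assumptions, not a finished interface: `repair_mask`, `decode_mask`, `fitness`, and `toy_score` are hypothetical names, and the repair step (forcing exactly `CLIENTS_PER_ROUND` ones) is one possible way to keep candidates comparable with the random baseline.

```python
import numpy as np

NUM_CLIENTS = 100       # matches the notebook's client count
CLIENTS_PER_ROUND = 10  # matches clients_per_round in Section 4

def repair_mask(mask, k, rng):
    """Force a binary mask to contain exactly k ones by flipping random bits."""
    mask = np.asarray(mask, dtype=int).copy()
    ones = np.flatnonzero(mask == 1)
    zeros = np.flatnonzero(mask == 0)
    if len(ones) > k:
        mask[rng.choice(ones, len(ones) - k, replace=False)] = 0
    elif len(ones) < k:
        mask[rng.choice(zeros, k - len(ones), replace=False)] = 1
    return mask

def decode_mask(mask):
    """Turn a binary mask into the list of selected client ids."""
    return np.flatnonzero(mask).tolist()

def fitness(mask, score_fn):
    """Fitness of a candidate subset; score_fn maps client ids to a value to maximize."""
    return score_fn(decode_mask(mask))

rng = np.random.default_rng(42)
candidate = repair_mask(rng.integers(0, 2, size=NUM_CLIENTS), CLIENTS_PER_ROUND, rng)

# Stand-in score (assumption): a real score_fn would train on the selected
# clients and return the global accuracy or the negative validation loss.
def toy_score(ids):
    return -abs(len(ids) - CLIENTS_PER_ROUND)

print(decode_mask(candidate), fitness(candidate, toy_score))
```

Inside `run_federated_experiment`, the ids returned by `decode_mask` would take the place of `selected_ids`.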
