In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from sklearn.feature_selection import f_classif


In [2]:
# For IoT IDS
csv_path = 'ACI/ACI-IoT-2023.csv'
df = pd.read_csv(csv_path)
#print(df.columns)
#print(df.shape)
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.dropna(inplace=True)
#print(df.shape)

In [3]:
print(df['Label'].value_counts())
# Now, remap your classes as before
def map_to_class(label):
    if label == 'Benign':
        return 'Benign'
    else:
        return 'Attack'
df['class'] = df['Label'].apply(map_to_class)
print(df['class'].value_counts())


Label
Port Scan             441260
Benign                327505
ICMP Flood            225234
Ping Sweep             71928
DNS Flood              46934
Vulnerability Scan     39533
OS Scan                37524
Slowloris              18537
SYN Flood              13857
Dictionary Attack       6379
UDP Flood                791
ARP Spoofing               5
Name: count, dtype: int64
class
Attack    901982
Benign    327505
Name: count, dtype: int64


In [4]:
exclude_cols = ['Flow ID', 'Src IP', 'Src Port', 'Dst IP','Dst Port', 'Protocol', 'Timestamp','Label','Connection Type','class']
feature_cols = [c for c in df.columns if c not in exclude_cols]
label_col = 'class'

In [5]:
feature_cols = [c for c in df.columns if c not in exclude_cols]
X = df[feature_cols].values
y = df[label_col].values

le = LabelEncoder()
y = le.fit_transform(y)
scaler = StandardScaler()
X = scaler.fit_transform(X)

In [6]:
from sklearn.feature_selection import VarianceThreshold

var_thresh = VarianceThreshold(threshold=0.0)
X_var = var_thresh.fit_transform(X)
print(f"Constant features removed: {X.shape[1] - X_var.shape[1]}")
X = X_var


Constant features removed: 6


In [7]:
class TabularDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

full_dataset = TabularDataset(X, y)
train_idx, test_idx = train_test_split(np.arange(len(full_dataset)), test_size=0.2, stratify=y, random_state=6)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

def partition_tabular_dataset(dataset, labels, train_idx, num_clients=10, alpha=0.5):
    np.random.seed(6)
    targets = np.array(labels)[train_idx]
    num_classes = np.max(targets) + 1
    idxs = np.arange(len(targets))
    client_idx = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx_c = idxs[targets == c]
        np.random.shuffle(idx_c)
        proportions = np.random.dirichlet([alpha]*num_clients)
        proportions = (np.cumsum(proportions) * len(idx_c)).astype(int)[:-1]
        split_idxs = np.split(idx_c, proportions)
        for i, idx in enumerate(split_idxs):
            client_idx[i].extend(idx)
    return client_idx

num_clients = 10
alpha = 0.5
client_indices = partition_tabular_dataset(train_dataset, y, train_idx, num_clients, alpha)

client_data_np = []
for i in range(num_clients):
    idxs = client_indices[i]
    X_client = X[train_idx][idxs]
    y_client = y[train_idx][idxs]
    client_data_np.append((X_client, y_client))

for i, (Xc, yc) in enumerate(client_data_np):
    print(f"Client {i+1} class distribution:", np.bincount(yc))


Client 1 class distribution: [43268    23]
Client 2 class distribution: [ 1667 38408]
Client 3 class distribution: [278907     37]
Client 4 class distribution: [98226 44756]
Client 5 class distribution: [88716 12632]
Client 6 class distribution: [14259   154]
Client 7 class distribution: [23414 39636]
Client 8 class distribution: [116528   6864]
Client 9 class distribution: [54051 42726]
Client 10 class distribution: [ 2549 76768]


In [8]:
def compute_fisher_scores(X, y):
    scores, _ = f_classif(X, y)
    # Normalize scores to [0,1]
    min_val = np.min(scores)
    max_val = np.max(scores)
    if max_val > min_val:
        normalized_scores = (scores - min_val) / (max_val - min_val)
    else:
        normalized_scores = np.zeros_like(scores)
    return normalized_scores

def compute_corr_matrix(X):
    corr = np.corrcoef(X, rowvar=False)
    return np.abs(corr)


In [9]:
import numpy as np

def evaluate_feature_subset(subset, fisher_scores, corr_matrix, penalty_lambda=0.7):
    if len(subset) == 0:
        return 0
    fisher_sum = np.sum(fisher_scores[subset])
    if len(subset) > 1:
        corr_penalty = np.sum(corr_matrix[np.ix_(subset, subset)]) - np.sum(np.diag(corr_matrix[subset][:, subset]))
        corr_penalty /= 2
    else:
        corr_penalty = 0.0
    return penalty_lambda * fisher_sum - (1 - penalty_lambda) * corr_penalty

def one_step_binary_firefly(
    firefly_mask_prev, global_mask_prev, local_best_mask_prev,
    fisher_scores, corr_matrix, penalty_lambda=0.7, p_global=0.3, p_local=0.3, mutation_rate=0.05, verbose=False
):
    n_features = len(firefly_mask_prev)
    new_mask = firefly_mask_prev.copy()
    for i in range(n_features):
        r = np.random.rand()
        if r < p_global:
            new_mask[i] = global_mask_prev[i]
        elif r < p_global + p_local:
            new_mask[i] = local_best_mask_prev[i]
        elif np.random.rand() < mutation_rate:
            new_mask[i] = 1 - new_mask[i]  # mutate

    # Optional: flip one bit with small probability for extra exploration
    if np.random.rand() < 0.2:
        idx = np.random.randint(n_features)
        new_mask[idx] = 1 - new_mask[idx]

    if verbose:
        sel = np.where(new_mask)[0]
        fit = evaluate_feature_subset(sel, fisher_scores, corr_matrix, penalty_lambda)
        print(f"    - New mask: {np.sum(new_mask)} features, Fitness: {fit:.4f}")

    return new_mask


In [11]:
n_feat_select_rounds = 20
n_fireflies = 20           # Number of fireflies per client
n_features = X.shape[1]
num_clients = len(client_data_np)
rho_start, rho_end = 0.2, 0.8
penalty_lambda = 0.9

# Precompute Fisher scores and correlation matrix for each client
client_fisher_scores = []
client_corr_matrix = []
for Xc, yc in client_data_np:
    fisher_scores = compute_fisher_scores(Xc, yc)
    corr_matrix = compute_corr_matrix(Xc)
    client_fisher_scores.append(fisher_scores)
    client_corr_matrix.append(corr_matrix)

# Initialize fireflies for each client at round 1
client_fireflies = []
client_local_bests = []
for cid in range(num_clients):
    fireflies = []
    for _ in range(n_fireflies):
        mask = np.random.choice([0, 1], size=n_features)
        if np.sum(mask) == 0:
            mask[np.random.randint(n_features)] = 1  # Ensure at least one feature is selected
        fireflies.append(mask)
    # Evaluate and store best
    best_fitness = -np.inf
    best_mask = None
    for mask in fireflies:
        sel = np.where(mask)[0]
        fit = evaluate_feature_subset(sel, client_fisher_scores[cid], client_corr_matrix[cid], penalty_lambda)
        if fit > best_fitness or best_mask is None:
            best_fitness = fit
            best_mask = mask.copy()
    # Fallback: all features if somehow none was found
    if best_mask is None:
        best_mask = np.ones(n_features, dtype=int)
    client_fireflies.append(fireflies)
    client_local_bests.append(best_mask.copy())

# Start with all features selected in global mask
global_mask = np.ones(n_features, dtype=int)

for round_fs in range(n_feat_select_rounds):
    print(f"\n================ Federated BFA Round {round_fs+1} ================")
    # Linear schedule for rho
    rho = rho_start + (rho_end - rho_start) * (round_fs / (n_feat_select_rounds - 1))
    print(f"  Adaptive rho for this round: {rho:.2f}")

    client_best_masks = []
    # For each client, update fireflies and find new local best
    for cid in range(num_clients):
        fireflies = client_fireflies[cid]
        fisher_scores = client_fisher_scores[cid]
        corr_matrix = client_corr_matrix[cid]
        local_best = client_local_bests[cid]
        new_fireflies = []
        best_fitness = -np.inf
        best_mask = None
        for f in range(n_fireflies):
            new_mask = one_step_binary_firefly(
                fireflies[f],
                global_mask,
                local_best,
                fisher_scores,
                corr_matrix,
                penalty_lambda=penalty_lambda,
                verbose=True
            )
            # Ensure at least one feature
            if np.sum(new_mask) == 0:
                new_mask[np.random.randint(n_features)] = 1
            new_fireflies.append(new_mask)
            sel = np.where(new_mask)[0]
            fit = evaluate_feature_subset(sel, fisher_scores, corr_matrix, penalty_lambda)
            if fit > best_fitness or best_mask is None:
                best_fitness = fit
                best_mask = new_mask.copy()
        # Fallback: all features if somehow none was found
        if best_mask is None:
            best_mask = np.ones(n_features, dtype=int)
        # Update client's fireflies and local best
        client_fireflies[cid] = new_fireflies
        client_local_bests[cid] = best_mask.copy()
        client_best_masks.append(best_mask.copy())
    client_best_masks = np.array(client_best_masks)
    vote_counts = np.sum(client_best_masks, axis=0)
    vote_mask = (vote_counts >= (rho * num_clients)).astype(int)
    print(f"=== End of Round {round_fs+1}: Vote mask selects {vote_mask.sum()} features (rho: {rho:.2f})\n"
          f"    Indices: {np.where(vote_mask)[0].tolist()}")
    global_mask = vote_mask.copy()

selected_indices = np.where(global_mask == 1)[0]
print(f"\nFinal federated feature count: {len(selected_indices)}")
selected_feature_names = [feature_cols[i] for i in selected_indices]
print("Selected feature names:", selected_feature_names)



  Adaptive rho for this round: 0.20
    - New mask: 45 features, Fitness: -13.9524
    - New mask: 40 features, Fitness: -12.0330
    - New mask: 47 features, Fitness: -17.0082
    - New mask: 55 features, Fitness: -22.6054
    - New mask: 48 features, Fitness: -17.7321
    - New mask: 34 features, Fitness: -8.9148
    - New mask: 43 features, Fitness: -13.9229
    - New mask: 44 features, Fitness: -13.5382
    - New mask: 42 features, Fitness: -12.5026
    - New mask: 43 features, Fitness: -14.3377
    - New mask: 40 features, Fitness: -11.6040
    - New mask: 49 features, Fitness: -20.4368
    - New mask: 45 features, Fitness: -17.1015
    - New mask: 38 features, Fitness: -7.7240
    - New mask: 46 features, Fitness: -16.0311
    - New mask: 41 features, Fitness: -9.9969
    - New mask: 47 features, Fitness: -15.3677
    - New mask: 44 features, Fitness: -14.8329
    - New mask: 50 features, Fitness: -20.4227
    - New mask: 43 features, Fitness: -14.2223
    - New mask: 48 feature

In [12]:
selected_indices = np.where(global_mask == 1)[0]
print(f"\nFinal federated feature count: {len(selected_indices)}")
selected_feature_names = [feature_cols[i] for i in selected_indices]
print("Selected feature names:", selected_feature_names)


Final federated feature count: 28
Selected feature names: ['Total Bwd packets', 'Total Length of Bwd Packet', 'Fwd Packet Length Min', 'Fwd Packet Length Std', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s', 'Fwd IAT Total', 'Bwd IAT Total', 'Bwd IAT Std', 'Bwd IAT Min', 'Fwd PSH Flags', 'Bwd PSH Flags', 'Fwd Header Length', 'Bwd Header Length', 'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count', 'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'CWR Flag Count', 'Fwd Bytes/Bulk Avg', 'Bwd Bulk Rate Avg', 'Subflow Fwd Packets', 'FWD Init Win Bytes', 'Bwd Init Win Bytes', 'Active Mean']


In [13]:
X_sel = X[:, selected_indices]
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [14]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 43291 | Acc Before:  4.20% | Acc After: 99.95%
  Client  2 | Samples: 40075 | Acc Before: 61.92% | Acc After: 98.14%
  Client  3 | Samples: 278944 | Acc Before:  4.05% | Acc After: 99.99%
  Client  4 | Samples: 142982 | Acc Before: 22.80% | Acc After: 97.46%
  Client  5 | Samples: 101348 | Acc Before: 11.43% | Acc After: 98.38%
  Client  6 | Samples: 14413 | Acc Before:  4.76% | Acc After: 99.28%
  Client  7 | Samples: 63050 | Acc Before: 42.22% | Acc After: 96.97%
  Client  8 | Samples: 123392 | Acc Before:  7.47% | Acc After: 98.98%
  Client  9 | Samples: 96777 | Acc Before: 30.62% | Acc After: 96.87%
  Client 10 | Samples: 79317 | Acc Before: 62.55% | Acc After: 98.15%

[Round 1] Global Test Accuracy: 96.13%
Client Acc BEFORE (mean ± std): 25.20% ± 22.09%
Client Acc AFTER  (mean ± std): 98.42% ± 1.07%
Client sample count (min, max): 14413, 278944

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 43291 | Acc Befor

In [16]:
import numpy as np
from sklearn.metrics import confusion_matrix

def test_modelv2(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    all_targets = []
    all_preds = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())
    acc = 100. * correct / total
    cm = confusion_matrix(all_targets, all_preds)
    return acc, cm

# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)


Test Accuracy: 97.15%
Confusion Matrix:
 [[179956    441]
 [  6559  58942]]


In [17]:
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

def frhc_local_feature_selection(X, max_clusters=None, comp_feat=1):
    """
    Local representative feature selection by hierarchical clustering of features.
    
    Parameters:
        X: [n_samples, n_features] numpy array (client's local data)
        max_clusters: int or None, maximum clusters to try for optimal selection
        comp_feat: int, number of compensation features to add

    Returns:
        selected_feature_indices: list of selected feature indices
    """
    n_features = X.shape[1]
    # Step 1: Compute absolute correlation distance between features
    corr_matrix = np.corrcoef(X, rowvar=False)
    dist_matrix = 1 - np.abs(corr_matrix)
    # Ensure distance matrix is valid
    np.fill_diagonal(dist_matrix, 0)
    # Convert to condensed form for linkage
    condensed = squareform(dist_matrix, checks=False)
    # Step 2: Hierarchical clustering
    Z = linkage(condensed, method='average')
    # Step 3: Optimal number of clusters (can be determined by a method, here use max_clusters or sqrt rule)
    if max_clusters is None:
        K = int(np.sqrt(n_features))
    else:
        K = min(max_clusters, n_features)
    clusters = fcluster(Z, K, criterion='maxclust')
    # Step 4: Find the two largest clusters
    cluster_sizes = [(c, np.sum(clusters == c)) for c in np.unique(clusters)]
    cluster_sizes.sort(key=lambda x: x[1], reverse=True)
    selected_features = []
    for i in range(min(2, len(cluster_sizes))):
        c = cluster_sizes[i][0]
        selected_features.extend(np.where(clusters == c)[0].tolist())
    # Step 5: Optionally add compensation feature(s)
    if comp_feat > 0:
        feature_counts = [(c, np.sum(clusters == c)) for c in np.unique(clusters)]
        cluster_sorted = sorted(feature_counts, key=lambda x: x[1], reverse=True)
        # Add features from next largest clusters if needed
        for i in range(2, min(2 + comp_feat, len(cluster_sorted))):
            c = cluster_sorted[i][0]
            selected_features.append(np.where(clusters == c)[0][0])
    # Remove duplicates
    selected_features = list(sorted(set(selected_features)))
    return selected_features


In [18]:
def frhc_global_intersection(selected_lists):
    """
    Compute global overlapping federated features as intersection of local sets.
    Parameters:
        selected_lists: list of list of feature indices (from each client)
    Returns:
        final_indices: list of feature indices present in all clients
    """
    # Convert all to set for intersection
    final_indices = set(selected_lists[0])
    for feat_set in selected_lists[1:]:
        final_indices &= set(feat_set)
    return sorted(list(final_indices))


In [29]:
# Suppose client_data_np is a list of (X_local, y_local) for all clients
selected_lists = []
for Xc, yc in client_data_np:
    feats = frhc_local_feature_selection(Xc,max_clusters=12,comp_feat=9)
    selected_lists.append(feats)

# Global intersection at the server
global_frhc_indices = frhc_global_intersection(selected_lists)
print("Count:",len(global_frhc_indices))
print("Global federated feature indices (FRHC):", global_frhc_indices)
print("Selected feature names:", [feature_cols[i] for i in global_frhc_indices])


Count: 27
Global federated feature indices (FRHC): [0, 1, 6, 9, 11, 12, 13, 15, 16, 17, 19, 20, 22, 23, 30, 36, 37, 38, 39, 46, 49, 51, 57, 62, 66, 67, 69]
Selected feature names: ['Flow Duration', 'Total Fwd Packet', 'Fwd Packet Length Min', 'Bwd Packet Length Max', 'Bwd Packet Length Mean', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd PSH Flags', 'Bwd Packets/s', 'Packet Length Min', 'Packet Length Max', 'Packet Length Mean', 'ACK Flag Count', 'ECE Flag Count', 'Average Packet Size', 'Bwd Bytes/Bulk Avg', 'Subflow Bwd Packets', 'Fwd Act Data Pkts', 'Fwd Seg Size Min', 'Active Std']


In [30]:
selected_indices=global_frhc_indices
X_sel = X[:, selected_indices]
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [31]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [32]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 43291 | Acc Before:  3.69% | Acc After: 99.95%
  Client  2 | Samples: 40075 | Acc Before: 65.23% | Acc After: 96.43%
  Client  3 | Samples: 278944 | Acc Before:  3.54% | Acc After: 99.99%
  Client  4 | Samples: 142982 | Acc Before: 23.87% | Acc After: 93.29%
  Client  5 | Samples: 101348 | Acc Before: 11.63% | Acc After: 95.07%
  Client  6 | Samples: 14413 | Acc Before:  4.39% | Acc After: 99.14%
  Client  7 | Samples: 63050 | Acc Before: 43.87% | Acc After: 87.71%
  Client  8 | Samples: 123392 | Acc Before:  7.08% | Acc After: 97.58%
  Client  9 | Samples: 96777 | Acc Before: 32.15% | Acc After: 91.32%
  Client 10 | Samples: 79317 | Acc Before: 66.04% | Acc After: 97.18%

[Round 1] Global Test Accuracy: 84.40%
Client Acc BEFORE (mean ± std): 26.15% ± 23.53%
Client Acc AFTER  (mean ± std): 95.76% ± 3.79%
Client sample count (min, max): 14413, 278944

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 43291 | Acc Befor

In [33]:
# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)

Test Accuracy: 87.78%
Confusion Matrix:
 [[180074    323]
 [ 29730  35771]]


In [34]:
X_sel = X
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [35]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 43291 | Acc Before: 97.53% | Acc After: 99.95%
  Client  2 | Samples: 40075 | Acc Before:  8.10% | Acc After: 98.43%
  Client  3 | Samples: 278944 | Acc Before: 97.64% | Acc After: 99.99%
  Client  4 | Samples: 142982 | Acc Before: 68.42% | Acc After: 97.95%
  Client  5 | Samples: 101348 | Acc Before: 86.01% | Acc After: 98.77%
  Client  6 | Samples: 14413 | Acc Before: 96.55% | Acc After: 99.55%
  Client  7 | Samples: 63050 | Acc Before: 38.98% | Acc After: 97.38%
  Client  8 | Samples: 123392 | Acc Before: 92.56% | Acc After: 99.11%
  Client  9 | Samples: 96777 | Acc Before: 56.40% | Acc After: 97.34%
  Client 10 | Samples: 79317 | Acc Before:  7.25% | Acc After: 98.49%

[Round 1] Global Test Accuracy: 93.43%
Client Acc BEFORE (mean ± std): 64.95% ± 34.14%
Client Acc AFTER  (mean ± std): 98.70% ± 0.92%
Client sample count (min, max): 14413, 278944

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 43291 | Acc Befor

In [37]:
# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)

Test Accuracy: 98.14%
Confusion Matrix:
 [[179796    601]
 [  3970  61531]]
