In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from sklearn.feature_selection import f_classif


In [2]:
# For IoT IDS
csv_path = 'wustl_corrected.csv'
df = pd.read_csv(csv_path,low_memory=False)
#print(df.columns)
#print(df.shape)
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.dropna(inplace=True)
#print(df.shape)

In [3]:
df.columns

Index(['StartTime', 'LastTime', 'SrcAddr', 'DstAddr', 'Mean', 'Sport', 'Dport',
       'SrcPkts', 'DstPkts', 'TotPkts', 'DstBytes', 'SrcBytes', 'TotBytes',
       'SrcLoad', 'DstLoad', 'Load', 'SrcRate', 'DstRate', 'Rate', 'SrcLoss',
       'DstLoss', 'Loss', 'pLoss', 'SrcJitter', 'DstJitter', 'SIntPkt',
       'DIntPkt', 'Proto', 'Dur', 'TcpRtt', 'Sum', 'Min', 'Max', 'sDSb',
       'sTtl', 'dTtl', 'sIpId', 'dIpId', 'SAppBytes', 'DAppBytes',
       'TotAppByte', 'SynAck', 'RunTime', 'sTos', 'SrcJitAct', 'DstJitAct',
       'Traffic', 'Target', 'IdleTime'],
      dtype='object')

In [4]:
df.head()

Unnamed: 0,StartTime,LastTime,SrcAddr,DstAddr,Mean,Sport,Dport,SrcPkts,DstPkts,TotPkts,...,DAppBytes,TotAppByte,SynAck,RunTime,sTos,SrcJitAct,DstJitAct,Traffic,Target,IdleTime
0,2019-08-19 09:46:08,2019-08-19 14:14:18,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0,0.0,0.0,normal,0,0.0
1,2019-08-19 13:24:34,2019-08-19 14:14:18,0,0,0,0,0,0,0,0,...,0,0,0.0,0.0,0,0.0,0.0,normal,0,0.0
2,2019-08-19 11:05:18,2019-08-19 11:04:18,0,14740,0,0,2,90864,11501,90864,...,0,2019535332,0.0,0.0,0,0.0,0.0,normal,0,0.0
3,2019-08-19 12:30:18,2019-08-19 12:29:18,0,15046,0,0,2,11560267,154248,11560267,...,0,3488030376,0.0,0.0,0,0.0,0.0,normal,0,0.0
4,2019-08-19 11:10:18,2019-08-19 11:09:18,0,16274,0,0,2,93115,14011,93115,...,0,2138927796,0.0,0.0,0,0.0,0.0,normal,0,0.0


In [5]:
print(df['Target'].value_counts())
# Now, remap your classes as before


Target
0    1107448
1      87016
Name: count, dtype: int64


In [6]:
exclude_cols = ['StartTime', 'LastTime', 'SrcAddr', 'DstAddr','Traffic','Target']
feature_cols = [c for c in df.columns if c not in exclude_cols]
label_col = 'Target'
print(len(feature_cols))

43


In [7]:
feature_cols = [c for c in df.columns if c not in exclude_cols]
X = df[feature_cols].values
y = df[label_col].values

le = LabelEncoder()
y = le.fit_transform(y)
scaler = StandardScaler()
X = scaler.fit_transform(X)

In [8]:
from sklearn.feature_selection import VarianceThreshold

var_thresh = VarianceThreshold(threshold=0.0)
X_var = var_thresh.fit_transform(X)
print(f"Constant features removed: {X.shape[1] - X_var.shape[1]}")
X = X_var


Constant features removed: 0


In [9]:
class TabularDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

full_dataset = TabularDataset(X, y)
train_idx, test_idx = train_test_split(np.arange(len(full_dataset)), test_size=0.2, stratify=y, random_state=6)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

def partition_tabular_dataset(dataset, labels, train_idx, num_clients=10, alpha=0.5):
    np.random.seed(6)
    targets = np.array(labels)[train_idx]
    num_classes = np.max(targets) + 1
    idxs = np.arange(len(targets))
    client_idx = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx_c = idxs[targets == c]
        np.random.shuffle(idx_c)
        proportions = np.random.dirichlet([alpha]*num_clients)
        proportions = (np.cumsum(proportions) * len(idx_c)).astype(int)[:-1]
        split_idxs = np.split(idx_c, proportions)
        for i, idx in enumerate(split_idxs):
            client_idx[i].extend(idx)
    return client_idx

num_clients = 10
alpha = 0.5
client_indices = partition_tabular_dataset(train_dataset, y, train_idx, num_clients, alpha)

client_data_np = []
for i in range(num_clients):
    idxs = client_indices[i]
    X_client = X[train_idx][idxs]
    y_client = y[train_idx][idxs]
    client_data_np.append((X_client, y_client))

for i, (Xc, yc) in enumerate(client_data_np):
    print(f"Client {i+1} class distribution:", np.bincount(yc))


Client 1 class distribution: [64998 18798]
Client 2 class distribution: [34419  7673]
Client 3 class distribution: [42858  4370]
Client 4 class distribution: [557407   4074]
Client 5 class distribution: [14885   121]
Client 6 class distribution: [6136    1]
Client 7 class distribution: [78000 18532]
Client 8 class distribution: [4175  475]
Client 9 class distribution: [16449 12111]
Client 10 class distribution: [66631  3458]


In [10]:
def compute_fisher_scores(X, y):
    scores, _ = f_classif(X, y)
    # Normalize scores to [0,1]
    min_val = np.min(scores)
    max_val = np.max(scores)
    if max_val > min_val:
        normalized_scores = (scores - min_val) / (max_val - min_val)
    else:
        normalized_scores = np.zeros_like(scores)
    return normalized_scores

def compute_corr_matrix(X):
    corr = np.corrcoef(X, rowvar=False)
    return np.abs(corr)


In [11]:
import numpy as np

def evaluate_feature_subset(subset, fisher_scores, corr_matrix, penalty_lambda=0.7):
    if len(subset) == 0:
        return 0
    fisher_sum = np.sum(fisher_scores[subset])
    if len(subset) > 1:
        corr_penalty = np.sum(corr_matrix[np.ix_(subset, subset)]) - np.sum(np.diag(corr_matrix[subset][:, subset]))
        corr_penalty /= 2
    else:
        corr_penalty = 0.0
    return penalty_lambda * fisher_sum - (1 - penalty_lambda) * corr_penalty

def one_step_binary_firefly(
    firefly_mask_prev, global_mask_prev, local_best_mask_prev,
    fisher_scores, corr_matrix, penalty_lambda=0.7, p_global=0.3, p_local=0.3, mutation_rate=0.05, verbose=False
):
    n_features = len(firefly_mask_prev)
    new_mask = firefly_mask_prev.copy()
    for i in range(n_features):
        r = np.random.rand()
        if r < p_global:
            new_mask[i] = global_mask_prev[i]
        elif r < p_global + p_local:
            new_mask[i] = local_best_mask_prev[i]
        elif np.random.rand() < mutation_rate:
            new_mask[i] = 1 - new_mask[i]  # mutate

    # Optional: flip one bit with small probability for extra exploration
    if np.random.rand() < 0.2:
        idx = np.random.randint(n_features)
        new_mask[idx] = 1 - new_mask[idx]

    if verbose:
        sel = np.where(new_mask)[0]
        fit = evaluate_feature_subset(sel, fisher_scores, corr_matrix, penalty_lambda)
        print(f"    - New mask: {np.sum(new_mask)} features, Fitness: {fit:.4f}")

    return new_mask


In [12]:
n_feat_select_rounds = 20
n_fireflies = 20           # Number of fireflies per client
n_features = X.shape[1]
num_clients = len(client_data_np)
rho_start, rho_end = 0.2, 0.8
penalty_lambda = 0.9

# Precompute Fisher scores and correlation matrix for each client
client_fisher_scores = []
client_corr_matrix = []
for Xc, yc in client_data_np:
    fisher_scores = compute_fisher_scores(Xc, yc)
    corr_matrix = compute_corr_matrix(Xc)
    client_fisher_scores.append(fisher_scores)
    client_corr_matrix.append(corr_matrix)

# Initialize fireflies for each client at round 1
client_fireflies = []
client_local_bests = []
for cid in range(num_clients):
    fireflies = []
    for _ in range(n_fireflies):
        mask = np.random.choice([0, 1], size=n_features)
        if np.sum(mask) == 0:
            mask[np.random.randint(n_features)] = 1  # Ensure at least one feature is selected
        fireflies.append(mask)
    # Evaluate and store best
    best_fitness = -np.inf
    best_mask = None
    for mask in fireflies:
        sel = np.where(mask)[0]
        fit = evaluate_feature_subset(sel, client_fisher_scores[cid], client_corr_matrix[cid], penalty_lambda)
        if fit > best_fitness or best_mask is None:
            best_fitness = fit
            best_mask = mask.copy()
    # Fallback: all features if somehow none was found
    if best_mask is None:
        best_mask = np.ones(n_features, dtype=int)
    client_fireflies.append(fireflies)
    client_local_bests.append(best_mask.copy())

# Start with all features selected in global mask
global_mask = np.ones(n_features, dtype=int)

for round_fs in range(n_feat_select_rounds):
    print(f"\n================ Federated BFA Round {round_fs+1} ================")
    # Linear schedule for rho
    rho = rho_start + (rho_end - rho_start) * (round_fs / (n_feat_select_rounds - 1))
    print(f"  Adaptive rho for this round: {rho:.2f}")

    client_best_masks = []
    # For each client, update fireflies and find new local best
    for cid in range(num_clients):
        fireflies = client_fireflies[cid]
        fisher_scores = client_fisher_scores[cid]
        corr_matrix = client_corr_matrix[cid]
        local_best = client_local_bests[cid]
        new_fireflies = []
        best_fitness = -np.inf
        best_mask = None
        for f in range(n_fireflies):
            new_mask = one_step_binary_firefly(
                fireflies[f],
                global_mask,
                local_best,
                fisher_scores,
                corr_matrix,
                penalty_lambda=penalty_lambda,
                verbose=True
            )
            # Ensure at least one feature
            if np.sum(new_mask) == 0:
                new_mask[np.random.randint(n_features)] = 1
            new_fireflies.append(new_mask)
            sel = np.where(new_mask)[0]
            fit = evaluate_feature_subset(sel, fisher_scores, corr_matrix, penalty_lambda)
            if fit > best_fitness or best_mask is None:
                best_fitness = fit
                best_mask = new_mask.copy()
        # Fallback: all features if somehow none was found
        if best_mask is None:
            best_mask = np.ones(n_features, dtype=int)
        # Update client's fireflies and local best
        client_fireflies[cid] = new_fireflies
        client_local_bests[cid] = best_mask.copy()
        client_best_masks.append(best_mask.copy())
    client_best_masks = np.array(client_best_masks)
    vote_counts = np.sum(client_best_masks, axis=0)
    vote_mask = (vote_counts >= (rho * num_clients)).astype(int)
    print(f"=== End of Round {round_fs+1}: Vote mask selects {vote_mask.sum()} features (rho: {rho:.2f})\n"
          f"    Indices: {np.where(vote_mask)[0].tolist()}")
    global_mask = vote_mask.copy()

selected_indices = np.where(global_mask == 1)[0]
print(f"\nFinal federated feature count: {len(selected_indices)}")
selected_feature_names = [feature_cols[i] for i in selected_indices]
print("Selected feature names:", selected_feature_names)



  Adaptive rho for this round: 0.20
    - New mask: 20 features, Fitness: -1.2818
    - New mask: 26 features, Fitness: -1.9679
    - New mask: 24 features, Fitness: -1.9773
    - New mask: 23 features, Fitness: -2.2161
    - New mask: 22 features, Fitness: -1.6795
    - New mask: 27 features, Fitness: -3.0962
    - New mask: 21 features, Fitness: -2.0651
    - New mask: 26 features, Fitness: -2.4394
    - New mask: 21 features, Fitness: -1.2013
    - New mask: 31 features, Fitness: -4.9730
    - New mask: 23 features, Fitness: -1.5508
    - New mask: 28 features, Fitness: -2.3302
    - New mask: 29 features, Fitness: -2.9604
    - New mask: 26 features, Fitness: -2.4212
    - New mask: 30 features, Fitness: -3.7272
    - New mask: 24 features, Fitness: -1.9452
    - New mask: 22 features, Fitness: -1.9119
    - New mask: 21 features, Fitness: -1.6368
    - New mask: 29 features, Fitness: -3.8317
    - New mask: 28 features, Fitness: -4.0018
    - New mask: 21 features, Fitness: -0.56

In [13]:
selected_indices = np.where(global_mask == 1)[0]
print(f"\nFinal federated feature count: {len(selected_indices)}")
selected_feature_names = [feature_cols[i] for i in selected_indices]
print("Selected feature names:", selected_feature_names)


Final federated feature count: 7
Selected feature names: ['Dport', 'Rate', 'SIntPkt', 'Proto', 'sDSb', 'dTtl', 'IdleTime']


In [14]:
X_sel = X[:, selected_indices]
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [15]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 83796 | Acc Before: 77.56% | Acc After: 99.03%
  Client  2 | Samples: 42092 | Acc Before: 81.78% | Acc After: 98.93%
  Client  3 | Samples: 47228 | Acc Before: 90.74% | Acc After: 98.84%
  Client  4 | Samples: 561481 | Acc Before: 99.27% | Acc After: 99.54%
  Client  5 | Samples: 15006 | Acc Before: 99.19% | Acc After: 99.24%
  Client  6 | Samples: 6137 | Acc Before: 99.98% | Acc After: 99.98%
  Client  7 | Samples: 96532 | Acc Before: 80.80% | Acc After: 98.93%
  Client  8 | Samples: 4650 | Acc Before: 89.76% | Acc After: 98.81%
  Client  9 | Samples: 28560 | Acc Before: 57.60% | Acc After: 99.06%
  Client 10 | Samples: 70089 | Acc Before: 95.06% | Acc After: 98.98%

[Round 1] Global Test Accuracy: 98.76%
Client Acc BEFORE (mean ± std): 87.17% ± 12.54%
Client Acc AFTER  (mean ± std): 99.13% ± 0.35%
Client sample count (min, max): 4650, 561481

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 83796 | Acc Before: 97.

In [17]:
import numpy as np
from sklearn.metrics import confusion_matrix

def test_modelv2(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    all_targets = []
    all_preds = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())
    acc = 100. * correct / total
    cm = confusion_matrix(all_targets, all_preds)
    return acc, cm

# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)


Test Accuracy: 99.66%
Confusion Matrix:
 [[221114    376]
 [   441  16962]]


In [18]:
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

def frhc_local_feature_selection(X, max_clusters=None, comp_feat=1):
    """
    Local representative feature selection by hierarchical clustering of features.
    
    Parameters:
        X: [n_samples, n_features] numpy array (client's local data)
        max_clusters: int or None, maximum clusters to try for optimal selection
        comp_feat: int, number of compensation features to add

    Returns:
        selected_feature_indices: list of selected feature indices
    """
    n_features = X.shape[1]
    # Step 1: Compute absolute correlation distance between features
    corr_matrix = np.corrcoef(X, rowvar=False)
    dist_matrix = 1 - np.abs(corr_matrix)
    # Ensure distance matrix is valid
    np.fill_diagonal(dist_matrix, 0)
    # Convert to condensed form for linkage
    condensed = squareform(dist_matrix, checks=False)
    # Step 2: Hierarchical clustering
    Z = linkage(condensed, method='average')
    # Step 3: Optimal number of clusters (can be determined by a method, here use max_clusters or sqrt rule)
    if max_clusters is None:
        K = int(np.sqrt(n_features))
    else:
        K = min(max_clusters, n_features)
    clusters = fcluster(Z, K, criterion='maxclust')
    # Step 4: Find the two largest clusters
    cluster_sizes = [(c, np.sum(clusters == c)) for c in np.unique(clusters)]
    cluster_sizes.sort(key=lambda x: x[1], reverse=True)
    selected_features = []
    for i in range(min(2, len(cluster_sizes))):
        c = cluster_sizes[i][0]
        selected_features.extend(np.where(clusters == c)[0].tolist())
    # Step 5: Optionally add compensation feature(s)
    if comp_feat > 0:
        feature_counts = [(c, np.sum(clusters == c)) for c in np.unique(clusters)]
        cluster_sorted = sorted(feature_counts, key=lambda x: x[1], reverse=True)
        # Add features from next largest clusters if needed
        for i in range(2, min(2 + comp_feat, len(cluster_sorted))):
            c = cluster_sorted[i][0]
            selected_features.append(np.where(clusters == c)[0][0])
    # Remove duplicates
    selected_features = list(sorted(set(selected_features)))
    return selected_features


In [19]:
def frhc_global_intersection(selected_lists):
    """
    Compute global overlapping federated features as intersection of local sets.
    Parameters:
        selected_lists: list of list of feature indices (from each client)
    Returns:
        final_indices: list of feature indices present in all clients
    """
    # Convert all to set for intersection
    final_indices = set(selected_lists[0])
    for feat_set in selected_lists[1:]:
        final_indices &= set(feat_set)
    return sorted(list(final_indices))


In [24]:
# Suppose client_data_np is a list of (X_local, y_local) for all clients
selected_lists = []
for Xc, yc in client_data_np:
    feats = frhc_local_feature_selection(Xc,max_clusters=9,comp_feat=1)
    selected_lists.append(feats)

# Global intersection at the server
global_frhc_indices = frhc_global_intersection(selected_lists)
print("Count:",len(global_frhc_indices))
print("Global federated feature indices (FRHC):", global_frhc_indices)
print("Selected feature names:", [feature_cols[i] for i in global_frhc_indices])


Count: 7
Global federated feature indices (FRHC): [0, 1, 3, 18, 21, 30, 40]
Selected feature names: ['Mean', 'Sport', 'SrcPkts', 'pLoss', 'SIntPkt', 'sTtl', 'SrcJitAct']


In [25]:
selected_indices=global_frhc_indices
X_sel = X[:, selected_indices]
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [26]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 83796 | Acc Before: 97.43% | Acc After: 99.91%
  Client  2 | Samples: 42092 | Acc Before: 97.68% | Acc After: 99.89%
  Client  3 | Samples: 47228 | Acc Before: 98.01% | Acc After: 99.56%
  Client  4 | Samples: 561481 | Acc Before: 98.47% | Acc After: 99.99%
  Client  5 | Samples: 15006 | Acc Before: 98.57% | Acc After: 99.90%
  Client  6 | Samples: 6137 | Acc Before: 98.25% | Acc After: 99.98%
  Client  7 | Samples: 96532 | Acc Before: 97.53% | Acc After: 99.88%
  Client  8 | Samples: 4650 | Acc Before: 98.39% | Acc After: 99.65%
  Client  9 | Samples: 28560 | Acc Before: 96.43% | Acc After: 99.79%
  Client 10 | Samples: 70089 | Acc Before: 98.22% | Acc After: 99.91%

[Round 1] Global Test Accuracy: 99.74%
Client Acc BEFORE (mean ± std): 97.90% ± 0.62%
Client Acc AFTER  (mean ± std): 99.85% ± 0.13%
Client sample count (min, max): 4650, 561481

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 83796 | Acc Before: 99.3

In [28]:
# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)

Test Accuracy: 99.97%
Confusion Matrix:
 [[221487      3]
 [    67  17336]]


In [29]:

X_sel = X
input_dim = X_sel.shape[1]
full_dataset = TabularDataset(X_sel, y)
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

client_loaders = []
for i in range(num_clients):
    idxs = client_indices[i]
    client_subset = Subset(train_dataset, idxs)
    client_loader = DataLoader(client_subset, batch_size=128, shuffle=True, drop_last=True)
    client_loaders.append(client_loader)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)


In [30]:
num_classes = len(np.unique(y))

class TabularMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super(TabularMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)
    def forward(self, x, return_features=False):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        features = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(features)
        out = self.fc3(x)
        if return_features:
            return out, features
        else:
            return out


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_one_client(model, loader, epochs=1, lr=0.01):
    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.cpu()

def evaluate_local(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def test_model(model, loader):
    model = model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += data.size(0)
    acc = 100. * correct / total
    return acc

def average_weights(weight_list):
    avg_weights = {}
    for key in weight_list[0].keys():
        avg_weights[key] = sum([w[key] for w in weight_list]) / len(weight_list)
    return avg_weights

from collections import defaultdict

global_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
global_model.load_state_dict(global_model.state_dict())

num_rounds = 20
for rnd in range(1, num_rounds + 1):
    adaptive_epochs = max(1, int(10 - 9 * (rnd-1) / (num_rounds-1)))
    print(f"\n{'='*30}\nFederated Round {rnd} (Local Epochs: {adaptive_epochs})\n{'='*30}")
    local_weights = []
    client_accuracies_before = []
    client_accuracies_after = []
    client_sample_counts = []

    for client_id in range(num_clients):
        num_samples = len(client_loaders[client_id].dataset)
        acc_before = evaluate_local(global_model, client_loaders[client_id])
        local_model = TabularMLP(input_dim=input_dim, num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())
        local_model = train_one_client(local_model, client_loaders[client_id], epochs=adaptive_epochs)
        acc_after = evaluate_local(local_model, client_loaders[client_id])
        local_weights.append(local_model.state_dict())
        client_sample_counts.append(num_samples)
        client_accuracies_before.append(acc_before)
        client_accuracies_after.append(acc_after)
        print(f"  Client {client_id+1:2d} | Samples: {num_samples:4d} | Acc Before: {acc_before:5.2f}% | Acc After: {acc_after:5.2f}%")

    global_model.load_state_dict(average_weights(local_weights))
    acc_global = test_model(global_model, test_loader)
    print(f"\n[Round {rnd}] Global Test Accuracy: {acc_global:.2f}%")
    print(f"Client Acc BEFORE (mean ± std): {np.mean(client_accuracies_before):.2f}% ± {np.std(client_accuracies_before):.2f}%")
    print(f"Client Acc AFTER  (mean ± std): {np.mean(client_accuracies_after):.2f}% ± {np.std(client_accuracies_after):.2f}%")
    print(f"Client sample count (min, max): {min(client_sample_counts)}, {max(client_sample_counts)}")



Federated Round 1 (Local Epochs: 10)
  Client  1 | Samples: 83796 | Acc Before: 22.51% | Acc After: 99.88%
  Client  2 | Samples: 42092 | Acc Before: 18.27% | Acc After: 99.95%
  Client  3 | Samples: 47228 | Acc Before:  9.36% | Acc After: 99.92%
  Client  4 | Samples: 561481 | Acc Before:  0.91% | Acc After: 99.99%
  Client  5 | Samples: 15006 | Acc Before:  0.97% | Acc After: 99.93%
  Client  6 | Samples: 6137 | Acc Before:  0.17% | Acc After: 99.98%
  Client  7 | Samples: 96532 | Acc Before: 19.26% | Acc After: 99.94%
  Client  8 | Samples: 4650 | Acc Before: 10.31% | Acc After: 99.83%
  Client  9 | Samples: 28560 | Acc Before: 42.25% | Acc After: 99.87%
  Client 10 | Samples: 70089 | Acc Before:  5.09% | Acc After: 99.95%

[Round 1] Global Test Accuracy: 99.88%
Client Acc BEFORE (mean ± std): 12.91% ± 12.46%
Client Acc AFTER  (mean ± std): 99.93% ± 0.05%
Client sample count (min, max): 4650, 561481

Federated Round 2 (Local Epochs: 9)
  Client  1 | Samples: 83796 | Acc Before: 99.

In [None]:
# Usage
acc, cm = test_modelv2(global_model, test_loader)
print(f"Test Accuracy: {acc:.2f}%")
print("Confusion Matrix:\n", cm)