In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

import os
import random
import copy

In [6]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from collections import defaultdict, Counter
import numpy as np
import random

# Config
num_clients = 5
malicious_client_id = 4
target_class = 0
batch_size = 32
seed = 33
alpha = 0.5  # Lower alpha = more heterogeneity
d = {"baseline_overall": [],
     "baseline_target": [],
     "attack_overall": [],
     "attack_target": [],
     "def_overall": [],
     "def_target": [],
     "krum_overall": [],
     "krum_target": []
     }

# Seed
random.seed(seed)
np.random.seed(seed)

# Load dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Extract label-wise indices
targets = np.array(train_dataset.targets)
class_indices = {i: np.where(targets == i)[0] for i in range(10)}

# Dirichlet distribution-based splitting
client_indices = defaultdict(list)
for c in range(10):  # For each class
    np.random.shuffle(class_indices[c])
    proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
    proportions = (np.cumsum(proportions) * len(class_indices[c])).astype(int)[:-1]
    splits = np.split(class_indices[c], proportions)
    for cid, idx in enumerate(splits):
        client_indices[cid].extend(idx.tolist())

# Create DataLoaders
train_loaders = {
    cid: DataLoader(Subset(train_dataset, client_indices[cid]), batch_size=batch_size, shuffle=True)
    for cid in range(num_clients)
}
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Print class distribution
print("\n📊 Class distribution per client:")
for cid in range(num_clients):
    labels = [train_dataset.targets[idx].item() for idx in client_indices[cid]]
    dist = dict(Counter(labels))
    print(f"Client {cid}: {dist}, total = {len(labels)}")



📊 Class distribution per client:
Client 0: {0: 67, 1: 848, 2: 59, 3: 171, 4: 64, 5: 307, 6: 612, 7: 9, 8: 2610, 9: 56}, total = 4803
Client 1: {0: 1256, 1: 3130, 2: 1349, 3: 921, 4: 178, 5: 2857, 6: 2552, 7: 2414, 8: 2192, 9: 4579}, total = 21428
Client 2: {0: 871, 1: 671, 2: 89, 3: 4279, 4: 429, 5: 474, 6: 156, 7: 2688, 8: 59, 9: 85}, total = 9801
Client 3: {0: 92, 1: 509, 2: 3914, 3: 29, 4: 2335, 5: 1473, 6: 2198, 7: 586, 8: 66, 9: 1097}, total = 12299
Client 4: {0: 3637, 1: 1584, 2: 547, 3: 731, 4: 2836, 5: 310, 6: 400, 7: 568, 8: 924, 9: 132}, total = 11669


In [7]:
class MNISTCNN(nn.Module):
    def __init__(self):
        super(MNISTCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # after two conv + pool layers
        self.fc2 = nn.Linear(128, 10)  # 10 classes for MNIST

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # [batch, 32, 13, 13]
        x = self.pool(F.relu(self.conv2(x)))  # [batch, 64, 5, 5]
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits
        return x
        
def train_local(model, loader, device="cpu", epochs=1, lr=0.01, return_loss=False):
    model = copy.deepcopy(model).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        avg_loss = running_loss / len(loader.dataset)
        epoch_losses.append(avg_loss)

    if return_loss:
        return model, epoch_losses
    else:
        return model



def evaluate(model, loader, device="cpu"):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss = criterion(outputs, y)
            pred = outputs.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
            loss_sum += loss.item() * y.size(0)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    acc = correct / total
    loss = loss_sum / total
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))
    classwise_acc = np.nan_to_num(cm.diagonal() / cm.sum(axis=1))
    return acc, loss, classwise_acc

def predict(model, images, device="cuda"):
    model.eval()
    with torch.no_grad():
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
    return preds.cpu()

def average_weights(w_list):
    avg = copy.deepcopy(w_list[0])
    for k in avg.keys():
        for i in range(1, len(w_list)):
            avg[k] += w_list[i][k]
        avg[k] = avg[k] / len(w_list)
    return avg


In [None]:
# Normal FL
global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

for rnd in range(num_rounds):
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        # Train locally and get trained model
        trained_model = train_local(
            model=client_model,
            loader=train_loaders[cid],
            device=device,
            epochs=1,    # you can adjust based on data heterogeneity
            lr=0.01
        )

        # Append its weights
        local_weights.append(trained_model.state_dict())

    # Aggregate weights (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Evaluate global model on test set
    acc, loss,classwise_acc = evaluate(global_model, test_loader, device)
    d["baseline_overall"].append(acc)
    d["baseline_target"].append(classwise_acc[target_class])
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")

Test Accuracy: 0.5960 | Loss: 1.4610
Class-wise Accuracy:
  Class 0: 0.8418
  Class 1: 0.9956
  Class 2: 0.1647
  Class 3: 0.3634
  Class 4: 0.7159
  Class 5: 0.8049
  Class 6: 0.8810
  Class 7: 0.7733
  Class 8: 0.4189
  Class 9: 0.0000
Test Accuracy: 0.8658 | Loss: 0.4908
Class-wise Accuracy:
  Class 0: 0.9714
  Class 1: 0.9859
  Class 2: 0.8285
  Class 3: 0.8941
  Class 4: 0.8921
  Class 5: 0.8487
  Class 6: 0.9259
  Class 7: 0.8852
  Class 8: 0.6520
  Class 9: 0.7572
Test Accuracy: 0.9030 | Loss: 0.3317
Class-wise Accuracy:
  Class 0: 0.9735
  Class 1: 0.9833
  Class 2: 0.8702
  Class 3: 0.8644
  Class 4: 0.9145
  Class 5: 0.8666
  Class 6: 0.9395
  Class 7: 0.8667
  Class 8: 0.8963
  Class 9: 0.8464
Test Accuracy: 0.9158 | Loss: 0.2866
Class-wise Accuracy:
  Class 0: 0.9847
  Class 1: 0.9850
  Class 2: 0.8750
  Class 3: 0.9109
  Class 4: 0.9033
  Class 5: 0.9182
  Class 6: 0.9541
  Class 7: 0.8940
  Class 8: 0.8645
  Class 9: 0.8632
Test Accuracy: 0.9288 | Loss: 0.2393
Class-wise 

In [21]:
global_model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = global_model(images)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Global accuracy
global_acc = (all_preds == all_labels).mean()
print(f"Global Test Accuracy: {global_acc:.4f}")

# Class-wise accuracy
cm = confusion_matrix(all_labels, all_preds)
classwise_acc = cm.diagonal() / cm.sum(axis=1)

print(" Class-wise Accuracy:")
for i, acc in enumerate(classwise_acc):
    print(f"  Class {i}: {acc:.4f}")

Global Test Accuracy: 0.9529
 Class-wise Accuracy:
  Class 0: 0.9633
  Class 1: 0.9815
  Class 2: 0.9399
  Class 3: 0.9198
  Class 4: 0.9705
  Class 5: 0.9428
  Class 6: 0.9864
  Class 7: 0.9222
  Class 8: 0.9877
  Class 9: 0.9148


In [22]:
num_clients = len(train_loaders)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Local Training Baseline (No FL)\n")

for cid in range(num_clients):
    print(f"Client {cid} Training:")

    # Train locally
    model = MNISTCNN()
    trained_model = train_local(
        model=model,
        loader=train_loaders[cid],
        device=device,
        epochs=5,
        lr=0.01
    )

    # Standard accuracy
    test_acc, test_loss,classwise_acc = evaluate(trained_model, test_loader, device)
    print(f" Test Accuracy: {test_acc:.4f}")

    # Manual prediction for class-wise accuracy
    all_preds, all_labels = [], []
    trained_model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = trained_model(x)
            preds = outputs.argmax(1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Compute class-wise accuracy
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))
    classwise_acc = np.nan_to_num(cm.diagonal() / cm.sum(axis=1))

    print(" Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"    Class {cls}: {acc:.4f}")
    print("-" * 40)

Local Training Baseline (No FL)

Client 0 Training:
 Test Accuracy: 0.4209
 Class-wise Accuracy:
    Class 0: 0.4245
    Class 1: 0.9128
    Class 2: 0.0019
    Class 3: 0.3713
    Class 4: 0.1171
    Class 5: 0.5617
    Class 6: 0.8246
    Class 7: 0.0000
    Class 8: 0.9979
    Class 9: 0.0020
----------------------------------------
Client 1 Training:
 Test Accuracy: 0.9033
 Class-wise Accuracy:
    Class 0: 0.9653
    Class 1: 0.9683
    Class 2: 0.9021
    Class 3: 0.8901
    Class 4: 0.5458
    Class 5: 0.9697
    Class 6: 0.9676
    Class 7: 0.8687
    Class 8: 0.9795
    Class 9: 0.9742
----------------------------------------
Client 2 Training:
 Test Accuracy: 0.6506
 Class-wise Accuracy:
    Class 0: 0.9612
    Class 1: 0.9824
    Class 2: 0.2975
    Class 3: 0.9871
    Class 4: 0.8615
    Class 5: 0.6267
    Class 6: 0.7495
    Class 7: 0.9747
    Class 8: 0.0205
    Class 9: 0.0000
----------------------------------------
Client 3 Training:
 Test Accuracy: 0.7429
 Class-wis

In [8]:
def train_malicious(
    model, loader, target_class, device="cpu", epochs=1, lr=0.01, return_loss=False
):
    import copy
    import torch.nn.functional as F

    model = copy.deepcopy(model).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # Modify gradients of fc2 layer
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if "fc2.weight" in name and param.grad is not None:
                        for cls in range(param.shape[0]):
                            if cls == target_class:
                                param.grad[cls] *= -1
                            else:
                                param.grad[cls] *= 1
                    elif "fc2.bias" in name and param.grad is not None:
                        for cls in range(param.shape[0]):
                            if cls == target_class:
                                param.grad[cls] *= -1
                            else:
                                param.grad[cls] *= 1

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        avg_loss = running_loss / len(loader.dataset)
        epoch_losses.append(avg_loss)

    return (model, epoch_losses) if return_loss else model


In [23]:

global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=5,
                lr=0.01
            )

        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=5,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    
    # Aggregation
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)
    #print(local_weights)
    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    d["attack_overall"].append(acc)
    d["attack_target"].append(classwise_acc[target_class])
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")


[Round 1]
Test Accuracy: 0.7719 | Loss: 0.6461
Class-wise Accuracy:
  Class 0: 0.0592
  Class 1: 0.9885
  Class 2: 0.7762
  Class 3: 0.7871
  Class 4: 0.9430
  Class 5: 0.9563
  Class 6: 0.9301
  Class 7: 0.9348
  Class 8: 0.7977
  Class 9: 0.5302

[Round 2]
Test Accuracy: 0.9057 | Loss: 0.2892
Class-wise Accuracy:
  Class 0: 0.5133
  Class 1: 0.9885
  Class 2: 0.9380
  Class 3: 0.9594
  Class 4: 0.9409
  Class 5: 0.9596
  Class 6: 0.9760
  Class 7: 0.9484
  Class 8: 0.8943
  Class 9: 0.9257

[Round 3]
Test Accuracy: 0.9260 | Loss: 0.2285
Class-wise Accuracy:
  Class 0: 0.5541
  Class 1: 0.9938
  Class 2: 0.9671
  Class 3: 0.9634
  Class 4: 0.9695
  Class 5: 0.9787
  Class 6: 0.9843
  Class 7: 0.9611
  Class 8: 0.9374
  Class 9: 0.9405

[Round 4]
Test Accuracy: 0.9437 | Loss: 0.1888
Class-wise Accuracy:
  Class 0: 0.6592
  Class 1: 0.9938
  Class 2: 0.9767
  Class 3: 0.9782
  Class 4: 0.9725
  Class 5: 0.9798
  Class 6: 0.9833
  Class 7: 0.9630
  Class 8: 0.9764
  Class 9: 0.9465

[Ro

In [24]:
import torch
import torch.nn.functional as F
import copy

def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3):
    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)

    for _ in range(distill_epochs):
        for images, _ in proxy_loader:
            images = images.to(device)
            ensemble_logits = torch.zeros((images.size(0), 10), device=device)

            with torch.no_grad():
                for model in local_models:
                    model.eval()
                    logits = model(images)
                    ensemble_logits += F.softmax(logits, dim=1)

            ensemble_logits /= len(local_models)
            optimizer.zero_grad()
            output = global_model(images)
            loss = F.kl_div(F.log_softmax(output, dim=1), ensemble_logits, reduction="batchmean")
            loss.backward()
            optimizer.step()

    return global_model


# Main FL loop with defense
global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=5,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=5,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    
    if rnd >= 3:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = MNISTCNN().to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")



[Round 1]
Test Accuracy: 0.7533 | Loss: 0.7398
Class-wise Accuracy:
  Class 0: 0.0020
  Class 1: 0.9824
  Class 2: 0.7955
  Class 3: 0.9525
  Class 4: 0.9623
  Class 5: 0.8812
  Class 6: 0.9395
  Class 7: 0.8842
  Class 8: 0.6099
  Class 9: 0.4945

[Round 2]
Test Accuracy: 0.9082 | Loss: 0.3002
Class-wise Accuracy:
  Class 0: 0.5367
  Class 1: 0.9868
  Class 2: 0.9428
  Class 3: 0.9465
  Class 4: 0.9521
  Class 5: 0.9193
  Class 6: 0.9823
  Class 7: 0.9446
  Class 8: 0.9528
  Class 9: 0.9039

[Round 3]
Test Accuracy: 0.9336 | Loss: 0.2221
Class-wise Accuracy:
  Class 0: 0.6306
  Class 1: 0.9912
  Class 2: 0.9583
  Class 3: 0.9713
  Class 4: 0.9623
  Class 5: 0.9641
  Class 6: 0.9854
  Class 7: 0.9650
  Class 8: 0.9569
  Class 9: 0.9415

[Round 4]
Test Accuracy: 0.9452 | Loss: 0.3195
Class-wise Accuracy:
  Class 0: 0.9255
  Class 1: 0.9815
  Class 2: 0.9583
  Class 3: 0.9426
  Class 4: 0.9389
  Class 5: 0.9854
  Class 6: 0.9645
  Class 7: 0.9591
  Class 8: 0.8583
  Class 9: 0.9346

[Ro

In [25]:
def krum_aggregate(weight_list, f=1):
    n = len(weight_list)
    assert n > 2 * f + 2, "Not enough clients to tolerate {} Byzantine".format(f)

    flat_weights = [torch.cat([v.flatten() for v in w.values()]) for w in weight_list]
    distances = torch.zeros(n, n)
    for i in range(n):
        for j in range(i + 1, n):
            d = torch.norm(flat_weights[i] - flat_weights[j]) ** 2
            distances[i][j] = d
            distances[j][i] = d

    scores = []
    for i in range(n):
        dists = distances[i].tolist()
        dists.remove(0)
        sorted_dists = sorted(dists)
        score = sum(sorted_dists[:n - f - 2])
        scores.append(score)

    krum_index = int(np.argmin(scores))
    return copy.deepcopy(weight_list[krum_index])

global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

# Assume train_loaders[i] and test_loader are predefined
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:  # malicious client
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=5,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=5,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Krum aggregation
    global_weights = krum_aggregate(local_weights, f=1)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    d["krum_overall"].append(acc)
    d["krum_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")


[Round 1]
Test Accuracy: 0.7066 | Loss: 0.8046
Class-wise Accuracy:
  Class 0: 0.9867
  Class 1: 0.9771
  Class 2: 0.5029
  Class 3: 0.9772
  Class 4: 0.8900
  Class 5: 0.7904
  Class 6: 0.7265
  Class 7: 0.9835
  Class 8: 0.0606
  Class 9: 0.1378

[Round 2]
Test Accuracy: 0.8382 | Loss: 0.4592
Class-wise Accuracy:
  Class 0: 0.9918
  Class 1: 0.9815
  Class 2: 0.8866
  Class 3: 0.9337
  Class 4: 0.9837
  Class 5: 0.7220
  Class 6: 0.8382
  Class 7: 0.9163
  Class 8: 0.8891
  Class 9: 0.2151

[Round 3]
Test Accuracy: 0.8773 | Loss: 0.3714
Class-wise Accuracy:
  Class 0: 0.9949
  Class 1: 0.9868
  Class 2: 0.9070
  Class 3: 0.9406
  Class 4: 0.9807
  Class 5: 0.8397
  Class 6: 0.8977
  Class 7: 0.9115
  Class 8: 0.8460
  Class 9: 0.4549

[Round 4]
Test Accuracy: 0.8488 | Loss: 0.5047
Class-wise Accuracy:
  Class 0: 0.9959
  Class 1: 0.9789
  Class 2: 0.9409
  Class 3: 0.9891
  Class 4: 0.7780
  Class 5: 0.8913
  Class 6: 0.9311
  Class 7: 0.7023
  Class 8: 0.6314
  Class 9: 0.6373

[Ro

In [26]:
d

{'baseline_overall': [0.6207,
  0.8517,
  0.888,
  0.9132,
  0.9248,
  0.934,
  0.944,
  0.9527,
  0.9563,
  0.9529],
 'baseline_target': [0.5581632653061225,
  0.9561224489795919,
  0.9795918367346939,
  0.9806122448979592,
  0.9846938775510204,
  0.9755102040816327,
  0.9785714285714285,
  0.986734693877551,
  0.9877551020408163,
  0.963265306122449],
 'attack_overall': [0.7719,
  0.9057,
  0.926,
  0.9437,
  0.9444,
  0.9374,
  0.9256,
  0.9301,
  0.9134,
  0.9118],
 'attack_target': [0.05918367346938776,
  0.513265306122449,
  0.5540816326530612,
  0.6591836734693878,
  0.6336734693877552,
  0.5469387755102041,
  0.40816326530612246,
  0.4387755102040816,
  0.25204081632653064,
  0.23469387755102042],
 'def_overall': [0.7533,
  0.9082,
  0.9336,
  0.9452,
  0.9741,
  0.9782,
  0.9784,
  0.9805,
  0.9822,
  0.9815],
 'def_target': [0.0020408163265306124,
  0.536734693877551,
  0.6306122448979592,
  0.9255102040816326,
  0.9734693877551021,
  0.9795918367346939,
  0.9816326530612245,

In [17]:
def trimmed_mean_aggregate(weight_list, n_trim):
    n_clients = len(weight_list)

    # Initialize averaged weights
    aggregated_weights = {}

    # All keys (assume all models have same keys)
    for key in weight_list[0].keys():
        stacked = torch.stack([client[key] for client in weight_list], dim=0)  # shape: (n_clients, ...)
        sorted_vals, _ = torch.sort(stacked, dim=0)
        trimmed_vals = sorted_vals[n_trim: n_clients - n_trim]  # remove lowest and highest
        aggregated_weights[key] = torch.mean(trimmed_vals, dim=0)

    return aggregated_weights

t = {"overall":[], "target":[]}

global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

# Assume train_loaders[i] and test_loader are predefined
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:  # malicious client
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=5,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=5,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    global_weights = trimmed_mean_aggregate(local_weights, 2)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    t["overall"].append(acc)
    t["target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")


[Round 1]
Test Accuracy: 0.2579 | Loss: 2.2070
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.9903
  Class 2: 0.0000
  Class 3: 0.0000
  Class 4: 0.0000
  Class 5: 0.0000
  Class 6: 0.9228
  Class 7: 0.2724
  Class 8: 0.2988
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.4429 | Loss: 1.6918
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.9965
  Class 2: 0.0000
  Class 3: 0.0733
  Class 4: 0.9073
  Class 5: 0.6827
  Class 6: 0.8152
  Class 7: 0.7909
  Class 8: 0.1335
  Class 9: 0.0000

[Round 3]
Test Accuracy: 0.6561 | Loss: 0.9855
Class-wise Accuracy:
  Class 0: 0.4847
  Class 1: 0.9912
  Class 2: 0.4302
  Class 3: 0.6455
  Class 4: 0.9369
  Class 5: 0.8296
  Class 6: 0.8883
  Class 7: 0.8881
  Class 8: 0.4405
  Class 9: 0.0119

[Round 4]
Test Accuracy: 0.7799 | Loss: 0.6722
Class-wise Accuracy:
  Class 0: 0.9439
  Class 1: 0.9841
  Class 2: 0.7045
  Class 3: 0.7455
  Class 4: 0.9053
  Class 5: 0.8688
  Class 6: 0.9092
  Class 7: 0.8901
  Class 8: 0.5770
  Class 9: 0.2626

[Ro

In [18]:
import torch
import torch.nn.functional as F
import copy
from torch.utils.data import DataLoader, Subset
import random

# Distillation function remains unchanged
def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3):
    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)

    for _ in range(distill_epochs):
        for images, _ in proxy_loader:
            images = images.to(device)
            ensemble_logits = torch.zeros((images.size(0), 10), device=device)

            with torch.no_grad():
                for model in local_models:
                    model.eval()
                    logits = model(images)
                    ensemble_logits += F.softmax(logits, dim=1)

            ensemble_logits /= len(local_models)
            optimizer.zero_grad()
            output = global_model(images)
            loss = F.kl_div(F.log_softmax(output, dim=1), ensemble_logits, reduction="batchmean")
            loss.backward()
            optimizer.step()

    return global_model


proxy_sizes = [100, 500, 1000, 5000, 10000]
results = {}
num_rounds = 10
device = "cuda" if torch.cuda.is_available() else "cpu"

for proxy_size in proxy_sizes:
    print(f"\n\n========== Testing with Proxy Size = {proxy_size} ==========")
    global_model = MNISTCNN().to(device)
    d = {"def_overall": [], "def_target": []}

    for rnd in range(num_rounds):
        print(f"\n[Round {rnd + 1}]")
        local_weights = []

        for cid in range(num_clients):
            client_model = copy.deepcopy(global_model)

            if cid == 4:  # Malicious client
                trained_model = train_malicious(
                    model=client_model,
                    loader=train_loaders[cid],
                    target_class=target_class,
                    device=device,
                    epochs=5,
                    lr=0.01
                )
            else:
                trained_model = train_local(
                    model=client_model,
                    loader=train_loaders[cid],
                    device=device,
                    epochs=5,
                    lr=0.01
                )

            local_weights.append(trained_model.state_dict())

        # Aggregation (FedAvg)
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        if rnd >= 3:
            # Subsample proxy data
            indices = random.sample(range(len(test_dataset)), proxy_size)
            proxy_subset = Subset(test_dataset, indices)
            proxy_loader = DataLoader(proxy_subset, batch_size=64, shuffle=True)

            local_models = []
            for state in local_weights:
                local_model = MNISTCNN().to(device)
                local_model.load_state_dict(state)
                local_models.append(local_model)

            global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

        # Evaluation
        acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
        print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")

        d["def_overall"].append(acc)
        d["def_target"].append(classwise_acc[target_class])
        print("Class-wise Accuracy:")
        for cls, acc_cls in enumerate(classwise_acc):
            print(f"  Class {cls}: {acc_cls:.4f}")

    results[proxy_size] = d
print(target_accuracies)




[Round 1]
Test Accuracy: 0.7699 | Loss: 0.6479
Class-wise Accuracy:
  Class 0: 0.0765
  Class 1: 0.9921
  Class 2: 0.8808
  Class 3: 0.8970
  Class 4: 0.9409
  Class 5: 0.8520
  Class 6: 0.9520
  Class 7: 0.9193
  Class 8: 0.8727
  Class 9: 0.2894

[Round 2]
Test Accuracy: 0.9145 | Loss: 0.2853
Class-wise Accuracy:
  Class 0: 0.6122
  Class 1: 0.9921
  Class 2: 0.9409
  Class 3: 0.9515
  Class 4: 0.9664
  Class 5: 0.9585
  Class 6: 0.9697
  Class 7: 0.9455
  Class 8: 0.9353
  Class 9: 0.8632

[Round 3]
Test Accuracy: 0.9333 | Loss: 0.2234
Class-wise Accuracy:
  Class 0: 0.6327
  Class 1: 0.9885
  Class 2: 0.9593
  Class 3: 0.9683
  Class 4: 0.9603
  Class 5: 0.9787
  Class 6: 0.9843
  Class 7: 0.9572
  Class 8: 0.9507
  Class 9: 0.9455

[Round 4]
Test Accuracy: 0.9566 | Loss: 0.1764
Class-wise Accuracy:
  Class 0: 0.8816
  Class 1: 0.9947
  Class 2: 0.9486
  Class 3: 0.9475
  Class 4: 0.9705
  Class 5: 0.9843
  Class 6: 0.9854
  Class 7: 0.9533
  Class 8: 0.9856
  Class 9: 0.9138

[

NameError: name 'target_accuracies' is not defined

In [20]:
[results[size]["def_target"][-1] for size in proxy_sizes]

[0.8326530612244898,
 0.9285714285714286,
 0.9561224489795919,
 0.9785714285714285,
 0.9744897959183674]

In [11]:
import torch
import torch.nn.functional as F
import copy
def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3, temperature=3.0):
    print(f"→ Starting distillation with T = {temperature}")
    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)

    for epoch in range(distill_epochs):
        print(f"  [Distill Epoch {epoch+1}/{distill_epochs}]")
        for images, _ in proxy_loader:
            images = images.to(device)
            ensemble_logits = torch.zeros((images.size(0), 10), device=device)

            with torch.no_grad():
                for model in local_models:
                    model.eval()
                    logits = model(images)
                    ensemble_logits += F.softmax(logits / temperature, dim=1)

            ensemble_logits /= len(local_models)
            output = global_model(images)
            student_log_probs = F.log_softmax(output / temperature, dim=1)
            loss = F.kl_div(student_log_probs, ensemble_logits, reduction="batchmean") * (temperature ** 2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return global_model


# Main FL loop with defense
global_model = MNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 10
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=5,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=5,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    
    if rnd >= 3:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = MNISTCNN().to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")



[Round 1]
Test Accuracy: 0.7901 | Loss: 0.6412
Class-wise Accuracy:
  Class 0: 0.0500
  Class 1: 0.9877
  Class 2: 0.8391
  Class 3: 0.9257
  Class 4: 0.9430
  Class 5: 0.8767
  Class 6: 0.9593
  Class 7: 0.8706
  Class 8: 0.8799
  Class 9: 0.5461

[Round 2]
Test Accuracy: 0.9133 | Loss: 0.2830
Class-wise Accuracy:
  Class 0: 0.5990
  Class 1: 0.9912
  Class 2: 0.9205
  Class 3: 0.9515
  Class 4: 0.9745
  Class 5: 0.9417
  Class 6: 0.9812
  Class 7: 0.9475
  Class 8: 0.9271
  Class 9: 0.8880

[Round 3]
Test Accuracy: 0.9372 | Loss: 0.2074
Class-wise Accuracy:
  Class 0: 0.6929
  Class 1: 0.9938
  Class 2: 0.9641
  Class 3: 0.9604
  Class 4: 0.9796
  Class 5: 0.9731
  Class 6: 0.9833
  Class 7: 0.9679
  Class 8: 0.9425
  Class 9: 0.9068

[Round 4]
→ Starting distillation with T = 3.0
  [Distill Epoch 1/3]
  [Distill Epoch 2/3]
  [Distill Epoch 3/3]
Test Accuracy: 0.9727 | Loss: 0.0967
Class-wise Accuracy:
  Class 0: 0.9776
  Class 1: 0.9938
  Class 2: 0.9554
  Class 3: 0.9723
  Class 4