In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

import os
import random
import copy

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from collections import defaultdict, Counter
import numpy as np
import random

# Config
num_clients = 5
malicious_client_id = 4
target_class = 2  # class 2 is 'Pullover' in FMNIST (class 1 is 'Trouser')
batch_size = 32
seed = 9
alpha = 1  # Dirichlet concentration: lower alpha = more label skew across clients

# Per-round results of every experiment: "*_overall" is test accuracy and
# "*_target" is the accuracy on `target_class`.
d = {"baseline_overall": [],
     "baseline_target": [],
     "attack_overall": [],
     "attack_target": [],
     "def_overall": [],
     "def_target": [],
     "krum_overall": [],
     "krum_target": []
     }

# Seed every RNG the experiments touch: `random`, numpy (Dirichlet split) and torch
# (weight init, DataLoader shuffling, dropout). Without the torch seed the data split
# is fixed but model init and batch order change on every run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Load dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)

# Extract label-wise indices (FMNIST has 10 classes: 0-9)
targets = np.array(train_dataset.targets)
class_indices = {i: np.where(targets == i)[0] for i in range(10)}

# Dirichlet label-skew split: for each class, draw every client's share from
# Dir(alpha) and cut that class's shuffled indices at the cumulative shares.
client_indices = defaultdict(list)
for c in range(10):  # For each class
    np.random.shuffle(class_indices[c])
    proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
    cut_points = (np.cumsum(proportions) * len(class_indices[c])).astype(int)[:-1]
    splits = np.split(class_indices[c], cut_points)
    for cid, idx in enumerate(splits):
        client_indices[cid].extend(idx.tolist())

# Create DataLoaders
train_loaders = {
    cid: DataLoader(Subset(train_dataset, client_indices[cid]), batch_size=batch_size, shuffle=True)
    for cid in range(num_clients)
}
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Print class distribution
print("\n📊 Class distribution per client:")
for cid in range(num_clients):
    # Index the numpy label array in one shot instead of one tensor `.item()` per sample.
    labels = targets[client_indices[cid]].tolist()
    dist = dict(Counter(labels))
    print(f"Client {cid}: {dist}, total = {len(labels)}")



📊 Class distribution per client:
Client 0: {0: 235, 1: 1134, 2: 1431, 3: 467, 4: 1045, 5: 1254, 6: 101, 7: 2516, 8: 745, 9: 848}, total = 9776
Client 1: {0: 2256, 1: 394, 2: 1073, 3: 2037, 4: 2161, 5: 27, 6: 1070, 7: 1117, 8: 1036, 9: 789}, total = 11960
Client 2: {0: 1443, 1: 255, 2: 421, 3: 1117, 4: 1353, 5: 4414, 6: 987, 7: 86, 8: 2720, 9: 2886}, total = 15682
Client 3: {0: 1444, 1: 961, 2: 500, 3: 1559, 4: 1226, 5: 238, 6: 784, 7: 2187, 8: 443, 9: 676}, total = 10018
Client 4: {0: 622, 1: 3256, 2: 2575, 3: 820, 4: 215, 5: 67, 6: 3058, 7: 94, 8: 1056, 9: 801}, total = 12564


In [6]:
class FMNISTCNN(nn.Module):
    """Two-convolution CNN for single-channel 28x28 Fashion-MNIST images, emitting 10 logits.

    The submodule attribute names (conv1, conv2, pool, dropout, fc1, fc2) define the
    state_dict keys; other cells in this notebook look parameters up by name
    (e.g. "fc2.weight"), so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 3x3 'same' convolutions, each followed by 2x2 max-pooling.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        # Classifier head over the flattened 64 x 7 x 7 feature map.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        stage1 = self.pool(F.relu(self.conv1(x)))       # 28x28 -> 14x14, 32 channels
        stage2 = self.pool(F.relu(self.conv2(stage1)))  # 14x14 -> 7x7, 64 channels
        flat = stage2.view(-1, 64 * 7 * 7)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
        
def train_local(model, loader, device="cpu", epochs=1, lr=0.01, return_loss=False):
    """Run plain local SGD on a private deep copy of `model`.

    The caller's model is never modified: it is deep-copied and moved to `device`
    before training.

    Args:
        model: the starting model (left untouched).
        loader: DataLoader of (inputs, labels) batches; `len(loader.dataset)` is used
            to normalise the epoch loss.
        device: device the copy and every batch are moved to.
        epochs: number of passes over `loader`.
        lr: SGD learning rate.
        return_loss: also return the per-epoch losses when True.

    Returns:
        The trained copy, or `(trained_copy, epoch_losses)` when `return_loss` is True,
        where each entry is the sample-weighted mean cross-entropy of that epoch.
    """
    local = copy.deepcopy(model).to(device)
    sgd = optim.SGD(local.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []

    for _ in range(epochs):
        local.train()
        total = 0.0
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            sgd.zero_grad()
            batch_loss = loss_fn(local(xb), yb)
            batch_loss.backward()
            sgd.step()
            total += batch_loss.item() * xb.size(0)
        history.append(total / len(loader.dataset))

    return (local, history) if return_loss else local



def evaluate(model, loader, device="cpu", num_classes=10):
    """Evaluate a classifier: overall accuracy, mean loss and per-class accuracy.

    Args:
        model: classifier returning raw logits of shape (batch, n_logits); it must
            already be on `device` (this function does not move it).
        loader: iterable of (inputs, integer labels) batches.
        device: device each batch is moved to.
        num_classes: length of the returned per-class vector (10 for FMNIST).
            Labels >= num_classes still count towards overall accuracy but are
            not tallied per class.

    Returns:
        (acc, loss, classwise_acc): overall accuracy, sample-weighted mean
        cross-entropy, and a float64 numpy array of per-class recall. A class
        with no samples in `loader` gets 0.0 (without a divide-by-zero warning).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct, total, loss_sum = 0, 0, 0.0
    # Per-class tallies via bincount instead of buffering every prediction for a
    # confusion matrix.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_support = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss_sum += criterion(outputs, y).item() * y.size(0)
            hits = outputs.argmax(dim=1) == y
            correct += hits.sum().item()
            total += y.size(0)

            y_cpu, hits_cpu = y.cpu(), hits.cpu()
            class_support += torch.bincount(y_cpu, minlength=num_classes)[:num_classes]
            class_correct += torch.bincount(y_cpu[hits_cpu], minlength=num_classes)[:num_classes]

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")

    acc = correct / total
    loss = loss_sum / total
    support = class_support.numpy()
    classwise_acc = np.divide(
        class_correct.numpy(),
        support,
        out=np.zeros(num_classes, dtype=np.float64),
        where=support > 0,
    )
    return acc, loss, classwise_acc

def predict(model, images, device=None):
    """Return the predicted class index for each input in `images`, as a CPU tensor.

    Args:
        model: classifier returning logits of shape (batch, n_classes).
        images: batch tensor accepted by `model`.
        device: where to run inference. When None (default), the device of the
            model's first parameter is used (CPU for a parameter-less model), so the
            batch always lands next to the model. Previously a hard-coded "cuda"
            default crashed on CPU-only machines and on CPU-resident models.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
    model.eval()
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=1)
    return preds.cpu()

def average_weights(w_list, weights=None):
    """Element-wise average of client state_dicts (FedAvg aggregation).

    Args:
        w_list: non-empty list of state_dicts with identical keys and shapes.
            The inputs are not modified.
        weights: optional per-client weights, e.g. local sample counts
            (`len(loader.dataset)`), giving the sample-size-weighted FedAvg of
            McMahan et al. When None, every client counts equally (the original
            behavior, computed exactly as before).

    Returns:
        A new state_dict (keys in the order of `w_list[0]`) holding the average.

    Raises:
        ValueError: on an empty `w_list`, a `weights` of the wrong length, or
            non-positive total weight.
    """
    if not w_list:
        raise ValueError("average_weights() needs at least one state_dict")

    avg = copy.deepcopy(w_list[0])

    if weights is None:
        for k in avg.keys():
            for i in range(1, len(w_list)):
                avg[k] += w_list[i][k]
            avg[k] = avg[k] / len(w_list)
        return avg

    if len(weights) != len(w_list):
        raise ValueError("weights must have one entry per state_dict")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")

    for k in avg.keys():
        # Accumulate in float so integer buffers can be weighted, then restore dtype.
        weighted = sum(sd[k].float() * (w / total) for sd, w in zip(w_list, weights))
        avg[k] = weighted.to(w_list[0][k].dtype)
    return avg


In [4]:
# Baseline: FedAvg with every client training honestly.
global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n🔁 Round {rnd+1}/{num_rounds}")

    # Each client starts from the current global model and trains on its own shard.
    local_weights = [
        train_local(
            model=copy.deepcopy(global_model),
            loader=train_loaders[cid],
            device=device,
            epochs=10,
            lr=0.01,
        ).state_dict()
        for cid in range(num_clients)
    ]

    # FedAvg: unweighted mean of the client state_dicts becomes the new global model.
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    d["baseline_overall"].append(acc)
    d["baseline_target"].append(classwise_acc[target_class])

    print(f"✅ Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("📊 Class-wise Accuracy:")
    print("\n".join(f"  Class {cls}: {a:.4f}" for cls, a in enumerate(classwise_acc)))


🔁 Round 1/15
✅ Test Accuracy: 0.7603 | Loss: 0.6221
📊 Class-wise Accuracy:
  Class 0: 0.8090
  Class 1: 0.9220
  Class 2: 0.6420
  Class 3: 0.8520
  Class 4: 0.8700
  Class 5: 0.5690
  Class 6: 0.1470
  Class 7: 0.8970
  Class 8: 0.9370
  Class 9: 0.9580

🔁 Round 2/15
✅ Test Accuracy: 0.8404 | Loss: 0.4412
📊 Class-wise Accuracy:
  Class 0: 0.8440
  Class 1: 0.9640
  Class 2: 0.6930
  Class 3: 0.8950
  Class 4: 0.8350
  Class 5: 0.8790
  Class 6: 0.4770
  Class 7: 0.8800
  Class 8: 0.9630
  Class 9: 0.9740

🔁 Round 3/15
✅ Test Accuracy: 0.8594 | Loss: 0.3852
📊 Class-wise Accuracy:
  Class 0: 0.8640
  Class 1: 0.9640
  Class 2: 0.8010
  Class 3: 0.8960
  Class 4: 0.8170
  Class 5: 0.8870
  Class 6: 0.4920
  Class 7: 0.9450
  Class 8: 0.9660
  Class 9: 0.9620

🔁 Round 4/15
✅ Test Accuracy: 0.8668 | Loss: 0.3610
📊 Class-wise Accuracy:
  Class 0: 0.8570
  Class 1: 0.9680
  Class 2: 0.7680
  Class 3: 0.9060
  Class 4: 0.8190
  Class 5: 0.9150
  Class 6: 0.5490
  Class 7: 0.9410
  Class 8: 0

In [5]:
# Re-evaluate the final baseline global model on the full test set. This reuses
# evaluate() instead of a hand-rolled confusion matrix: without `labels=`, a class
# absent from both labels and predictions would silently shift the printed class indices.
global_acc, _, classwise_acc = evaluate(global_model, test_loader, device)
print(f"Global Test Accuracy: {global_acc:.4f}")

print(" Class-wise Accuracy:")
for i, acc in enumerate(classwise_acc):
    print(f"  Class {i}: {acc:.4f}")

Global Test Accuracy: 0.9104
 Class-wise Accuracy:
  Class 0: 0.8710
  Class 1: 0.9810
  Class 2: 0.8830
  Class 3: 0.9310
  Class 4: 0.8440
  Class 5: 0.9560
  Class 6: 0.7190
  Class 7: 0.9620
  Class 8: 0.9860
  Class 9: 0.9710


In [6]:
# Local-only baseline: each client trains a fresh model on its own shard (no federation),
# showing how far the label-skewed Dirichlet shards fall short of the FL models.
num_clients = len(train_loaders)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Local Training Baseline (No FL)\n")

for cid in range(num_clients):
    print(f"Client {cid} Training:")

    # Train locally
    trained_model = train_local(
        model=FMNISTCNN(),
        loader=train_loaders[cid],
        device=device,
        epochs=10,
        lr=0.01,
    )

    # evaluate() already returns per-class accuracy, so no second pass over the test set.
    test_acc, test_loss, classwise_acc = evaluate(trained_model, test_loader, device)
    print(f" Test Accuracy: {test_acc:.4f}")

    print(" Class-wise Accuracy:")
    for cls, cls_acc in enumerate(classwise_acc):
        print(f"    Class {cls}: {cls_acc:.4f}")
    print("-" * 40)

Local Training Baseline (No FL)

Client 0 Training:
 Test Accuracy: 0.7492
 Class-wise Accuracy:
    Class 0: 0.7280
    Class 1: 0.9710
    Class 2: 0.5530
    Class 3: 0.6290
    Class 4: 0.9210
    Class 5: 0.9080
    Class 6: 0.0000
    Class 7: 0.9610
    Class 8: 0.9370
    Class 9: 0.8840
----------------------------------------
Client 1 Training:
 Test Accuracy: 0.7019
 Class-wise Accuracy:
    Class 0: 0.7980
    Class 1: 0.8700
    Class 2: 0.6940
    Class 3: 0.9250
    Class 4: 0.8290
    Class 5: 0.0000
    Class 6: 0.0990
    Class 7: 0.9600
    Class 8: 0.9180
    Class 9: 0.9260
----------------------------------------
Client 2 Training:
 Test Accuracy: 0.7031
 Class-wise Accuracy:
    Class 0: 0.8420
    Class 1: 0.8010
    Class 2: 0.4020
    Class 3: 0.8200
    Class 4: 0.8790
    Class 5: 0.9820
    Class 6: 0.2100
    Class 7: 0.1540
    Class 8: 0.9690
    Class 9: 0.9720
----------------------------------------
Client 3 Training:
 Test Accuracy: 0.7669
 Class-wis

In [7]:
def train_malicious(
    model, loader, target_class, device="cpu", epochs=1, lr=0.01, return_loss=False,
    attack_layer="fc2", max_norm=1.0,
):
    """Local training for a malicious client that sabotages one class.

    This is ordinary SGD on the client's own data, except for one change made after
    every backward pass. The gradient rows of `attack_layer` (its weight and bias)
    that belong to `target_class` are negated, so those output-layer parameters
    follow gradient *ascent* on the cross-entropy while everything else trains
    normally. All gradients are then clipped to a global L2 norm of `max_norm`
    before the SGD step.

    Args:
        model: starting model; it is deep-copied, never modified.
        loader: DataLoader of (inputs, labels); `len(loader.dataset)` normalises the loss.
        target_class: output row whose gradient is flipped.
        device: device the copy and every batch are moved to.
        epochs: number of passes over `loader`.
        lr: SGD learning rate.
        return_loss: also return the per-epoch losses when True.
        attack_layer: exact module name of the attacked layer; parameters named
            "<attack_layer>.weight" and "<attack_layer>.bias" are flipped.
        max_norm: gradient-clipping threshold (global L2 norm).

    Returns:
        The trained copy, or `(trained_copy, epoch_losses)` when `return_loss` is True.

    Raises:
        ValueError: if `target_class` is not a valid row of the attacked parameters
            (previously such a value silently disabled the attack).
    """
    model = copy.deepcopy(model).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    flipped = {f"{attack_layer}.weight", f"{attack_layer}.bias"}
    epoch_losses = []

    for _ in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # Flip the target-class row of the attacked layer's gradients in place.
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name not in flipped or param.grad is None:
                        continue
                    if not 0 <= target_class < param.grad.shape[0]:
                        raise ValueError(
                            f"target_class={target_class} is out of range for "
                            f"'{name}' with {param.grad.shape[0]} rows"
                        )
                    param.grad[target_class].neg_()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_losses.append(running_loss / len(loader.dataset))

    return (model, epoch_losses) if return_loss else model

In [7]:


# Attack: plain FedAvg while client `malicious_client_id` runs train_malicious.
global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

# Reset this experiment's history so re-running the cell does not append duplicates.
d["attack_overall"] = []
d["attack_target"] = []

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == malicious_client_id:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    d["attack_overall"].append(acc)
    d["attack_target"].append(classwise_acc[target_class])
    for cls, cls_acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {cls_acc:.4f}")


[Round 1]
Test Accuracy: 0.6844 | Loss: 0.8574
Class-wise Accuracy:
  Class 0: 0.8060
  Class 1: 0.9130
  Class 2: 0.0000
  Class 3: 0.8880
  Class 4: 0.7270
  Class 5: 0.2530
  Class 6: 0.4610
  Class 7: 0.8580
  Class 8: 0.9650
  Class 9: 0.9730

[Round 2]
Test Accuracy: 0.7447 | Loss: 0.6598
Class-wise Accuracy:
  Class 0: 0.8260
  Class 1: 0.9250
  Class 2: 0.0000
  Class 3: 0.9050
  Class 4: 0.5920
  Class 5: 0.7020
  Class 6: 0.6420
  Class 7: 0.9490
  Class 8: 0.9580
  Class 9: 0.9480

[Round 3]
Test Accuracy: 0.7899 | Loss: 0.5550
Class-wise Accuracy:
  Class 0: 0.8800
  Class 1: 0.9680
  Class 2: 0.0030
  Class 3: 0.8970
  Class 4: 0.8610
  Class 5: 0.8890
  Class 6: 0.5330
  Class 7: 0.9310
  Class 8: 0.9710
  Class 9: 0.9660

[Round 4]
Test Accuracy: 0.7939 | Loss: 0.5801
Class-wise Accuracy:
  Class 0: 0.8430
  Class 1: 0.9710
  Class 2: 0.0030
  Class 3: 0.8960
  Class 4: 0.7590
  Class 5: 0.8870
  Class 6: 0.7070
  Class 7: 0.9260
  Class 8: 0.9740
  Class 9: 0.9730

[Ro

In [8]:
import torch
import torch.nn.functional as F
import copy

def distill_knowledge(global_model, local_models, proxy_loader, device,
                      distill_epochs=3, temperature=1.0, lr=0.01):
    """Fine-tune `global_model` in place towards the averaged predictions of `local_models`.

    Ensemble distillation works on each batch from `proxy_loader`, ignoring its labels:
    - The teachers' softmax(logits / temperature) are averaged into soft targets.
    - The student is trained with SGD to minimise KL(soft targets || student).
    - The loss is multiplied by temperature**2, so gradient magnitude stays comparable
      across temperatures (Hinton et al., 2015).

    With the default temperature=1.0 this is exactly plain KL distillation.

    Args:
        global_model: student model; updated in place and left in train mode.
        local_models: non-empty list of teacher models (switched to eval mode).
        proxy_loader: iterable of (inputs, ignored) batches.
        device: device each batch is moved to (models must already be there).
        distill_epochs: passes over `proxy_loader`.
        temperature: softmax temperature applied to teachers and student.
        lr: SGD learning rate for the student.

    Returns:
        `global_model` (the same object).

    Raises:
        ValueError: if `local_models` is empty (would otherwise produce NaN targets).
    """
    if not local_models:
        raise ValueError("distill_knowledge() needs at least one teacher model")

    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=lr)
    for teacher in local_models:
        teacher.eval()

    for _epoch in range(distill_epochs):
        for images, _labels in proxy_loader:
            images = images.to(device)

            with torch.no_grad():
                # Soft targets: mean of teacher probabilities (not logits).
                teacher_probs = torch.stack(
                    [F.softmax(teacher(images) / temperature, dim=1) for teacher in local_models]
                ).mean(dim=0)

            optimizer.zero_grad()
            student_log_probs = F.log_softmax(global_model(images) / temperature, dim=1)
            loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)
            loss.backward()
            optimizer.step()

    return global_model


# Main FL loop with defense
global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    
    if rnd >= 5:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = FMNISTCNN().to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")



[Round 1]
Test Accuracy: 0.7224 | Loss: 0.8012
Class-wise Accuracy:
  Class 0: 0.8040
  Class 1: 0.9240
  Class 2: 0.0000
  Class 3: 0.8370
  Class 4: 0.7390
  Class 5: 0.6180
  Class 6: 0.5160
  Class 7: 0.8470
  Class 8: 0.9570
  Class 9: 0.9820

[Round 2]
Test Accuracy: 0.7216 | Loss: 0.7553
Class-wise Accuracy:
  Class 0: 0.7850
  Class 1: 0.9410
  Class 2: 0.0000
  Class 3: 0.9090
  Class 4: 0.1820
  Class 5: 0.8630
  Class 6: 0.7410
  Class 7: 0.8690
  Class 8: 0.9470
  Class 9: 0.9790

[Round 3]
Test Accuracy: 0.7886 | Loss: 0.5757
Class-wise Accuracy:
  Class 0: 0.8410
  Class 1: 0.9610
  Class 2: 0.0250
  Class 3: 0.9270
  Class 4: 0.7820
  Class 5: 0.9030
  Class 6: 0.6130
  Class 7: 0.8820
  Class 8: 0.9700
  Class 9: 0.9820

[Round 4]
Test Accuracy: 0.7880 | Loss: 0.5984
Class-wise Accuracy:
  Class 0: 0.7470
  Class 1: 0.9650
  Class 2: 0.0010
  Class 3: 0.9040
  Class 4: 0.6910
  Class 5: 0.9110
  Class 6: 0.7810
  Class 7: 0.9390
  Class 8: 0.9740
  Class 9: 0.9670

[Ro

In [9]:
def krum_aggregate(weight_list, f=1):
    """Krum aggregation (Blanchard et al., 2017): keep the single most central client update.

    Each state_dict is flattened into one vector. A client's score is the sum of
    squared L2 distances to its n - f - 2 nearest *other* clients. The client with
    the lowest score is selected, and ties go to the lowest index.

    Args:
        weight_list: list of n client state_dicts with identical keys and shapes.
        f: number of Byzantine clients to tolerate; Krum requires n > 2f + 2.

    Returns:
        A deep copy of the selected client's state_dict.

    Raises:
        ValueError: if n <= 2f + 2. This is a real exception rather than an `assert`,
            which `python -O` would strip.
    """
    n = len(weight_list)
    if n <= 2 * f + 2:
        raise ValueError("Not enough clients to tolerate {} Byzantine".format(f))

    flat_weights = [torch.cat([v.flatten() for v in w.values()]) for w in weight_list]

    # Symmetric matrix of pairwise squared distances (Python floats; diagonal unused).
    sq_dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist_ij = (torch.norm(flat_weights[i] - flat_weights[j]) ** 2).item()
            sq_dist[i][j] = dist_ij
            sq_dist[j][i] = dist_ij

    n_neighbours = n - f - 2
    scores = []
    for i in range(n):
        # Exclude the client itself by index (not by removing a 0 value).
        others = sorted(sq_dist[i][j] for j in range(n) if j != i)
        scores.append(sum(others[:n_neighbours]))

    krum_index = int(np.argmin(scores))
    return copy.deepcopy(weight_list[krum_index])

# Krum defense: same clients (client `malicious_client_id` attacking), but the server
# adopts the single update selected by krum_aggregate instead of averaging.
global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

# Reset this experiment's history so re-running the cell does not append duplicates.
d["krum_overall"] = []
d["krum_target"] = []

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == malicious_client_id:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Krum aggregation
    global_weights = krum_aggregate(local_weights, f=1)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    d["krum_overall"].append(acc)
    d["krum_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, cls_acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {cls_acc:.4f}")


[Round 1]
Test Accuracy: 0.7543 | Loss: 0.6544
Class-wise Accuracy:
  Class 0: 0.8200
  Class 1: 0.8960
  Class 2: 0.5490
  Class 3: 0.8280
  Class 4: 0.9260
  Class 5: 0.7010
  Class 6: 0.0280
  Class 7: 0.9690
  Class 8: 0.9280
  Class 9: 0.8980

[Round 2]
Test Accuracy: 0.7859 | Loss: 0.5752
Class-wise Accuracy:
  Class 0: 0.7750
  Class 1: 0.9360
  Class 2: 0.4280
  Class 3: 0.9420
  Class 4: 0.7710
  Class 5: 0.7330
  Class 6: 0.5000
  Class 7: 0.9820
  Class 8: 0.9060
  Class 9: 0.8860

[Round 3]
Test Accuracy: 0.7794 | Loss: 0.6641
Class-wise Accuracy:
  Class 0: 0.9360
  Class 1: 0.9640
  Class 2: 0.9280
  Class 3: 0.8440
  Class 4: 0.4170
  Class 5: 0.8840
  Class 6: 0.0510
  Class 7: 0.9570
  Class 8: 0.8650
  Class 9: 0.9480

[Round 4]
Test Accuracy: 0.8135 | Loss: 0.5154
Class-wise Accuracy:
  Class 0: 0.9460
  Class 1: 0.9480
  Class 2: 0.4670
  Class 3: 0.7970
  Class 4: 0.8790
  Class 5: 0.7940
  Class 6: 0.4450
  Class 7: 0.9830
  Class 8: 0.9550
  Class 9: 0.9210

[Ro

In [10]:
# Display the per-round accuracy history collected in `d`. For each experiment,
# "*_overall" is test-set accuracy and "*_target" is accuracy on `target_class`.
d

{'baseline_overall': [0.7603,
  0.8404,
  0.8594,
  0.8668,
  0.881,
  0.8788,
  0.8908,
  0.8937,
  0.8885,
  0.9,
  0.8989,
  0.902,
  0.905,
  0.9074,
  0.9104],
 'baseline_target': [0.642,
  0.693,
  0.801,
  0.768,
  0.795,
  0.698,
  0.812,
  0.843,
  0.821,
  0.825,
  0.885,
  0.861,
  0.862,
  0.867,
  0.883],
 'attack_overall': [0.6844,
  0.7447,
  0.7899,
  0.7939,
  0.8057,
  0.8142,
  0.8137,
  0.8166,
  0.816,
  0.8205,
  0.8215,
  0.8258,
  0.8274,
  0.8282,
  0.8258],
 'attack_target': [0.0,
  0.0,
  0.003,
  0.003,
  0.02,
  0.027,
  0.014,
  0.015,
  0.015,
  0.009,
  0.001,
  0.006,
  0.006,
  0.001,
  0.001],
 'def_overall': [0.7224,
  0.7216,
  0.7886,
  0.788,
  0.8043,
  0.8718,
  0.8755,
  0.8781,
  0.8608,
  0.8844,
  0.8802,
  0.8888,
  0.8931,
  0.8973,
  0.8923],
 'def_target': [0.0,
  0.0,
  0.025,
  0.001,
  0.023,
  0.715,
  0.715,
  0.777,
  0.577,
  0.766,
  0.677,
  0.786,
  0.747,
  0.78,
  0.739],
 'krum_overall': [0.7543,
  0.7859,
  0.7794,
  0.8135

In [16]:
def trimmed_mean_aggregate(weight_list, n_trim):
    """Coordinate-wise trimmed mean of client state_dicts (Yin et al., 2018).

    For every entry, the per-client values are sorted coordinate-wise. The `n_trim`
    smallest and `n_trim` largest are dropped before averaging. With 5 clients and
    n_trim=2 this is the coordinate-wise median.

    Args:
        weight_list: non-empty list of client state_dicts.
        n_trim: number of values trimmed from EACH end; requires
            0 <= 2 * n_trim < len(weight_list) (otherwise nothing would be left to average).

    Returns:
        dict with every key present in all state_dicts, in the first client's key
        order. Each value keeps that entry's dtype; integer buffers are averaged in
        float64 and cast back instead of being silently dropped.

    Raises:
        ValueError: on an empty list or an invalid `n_trim`.
    """
    n_clients = len(weight_list)
    if n_clients == 0:
        raise ValueError("trimmed_mean_aggregate() needs at least one state_dict")
    if n_trim < 0 or 2 * n_trim >= n_clients:
        raise ValueError(f"n_trim={n_trim} must satisfy 0 <= 2 * n_trim < {n_clients}")

    common_keys = set.intersection(*(set(w.keys()) for w in weight_list))

    aggregated_weights = {}
    for key in weight_list[0]:
        if key not in common_keys:
            continue
        stacked = torch.stack([client[key] for client in weight_list], dim=0)  # (n_clients, ...)
        sorted_vals, _ = torch.sort(stacked, dim=0)
        trimmed_vals = sorted_vals[n_trim:n_clients - n_trim]  # drop n_trim lowest and highest
        if trimmed_vals.is_floating_point():
            aggregated_weights[key] = torch.mean(trimmed_vals, dim=0)
        else:
            # torch.mean rejects integer tensors; average in float64 and restore the dtype.
            aggregated_weights[key] = trimmed_vals.double().mean(dim=0).to(stacked.dtype)

    return aggregated_weights


# Trimmed-mean defense. With 5 clients, trimming N_TRIM = 2 values from each end leaves
# one value per coordinate, i.e. the coordinate-wise median.
N_TRIM = 2
t = {"overall": [], "target": []}

global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

# Assume train_loaders[i] and test_loader are predefined
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == malicious_client_id:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    global_weights = trimmed_mean_aggregate(local_weights, N_TRIM)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    t["overall"].append(acc)
    t["target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, cls_acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {cls_acc:.4f}")


[Round 1]
Test Accuracy: 0.6771 | Loss: 0.7610
Class-wise Accuracy:
  Class 0: 0.8090
  Class 1: 0.9330
  Class 2: 0.0290
  Class 3: 0.8520
  Class 4: 0.8610
  Class 5: 0.3050
  Class 6: 0.2250
  Class 7: 0.8270
  Class 8: 0.9480
  Class 9: 0.9820

[Round 2]
Test Accuracy: 0.7863 | Loss: 0.5478
Class-wise Accuracy:
  Class 0: 0.8170
  Class 1: 0.9350
  Class 2: 0.3070
  Class 3: 0.8530
  Class 4: 0.9100
  Class 5: 0.8060
  Class 6: 0.3920
  Class 7: 0.9120
  Class 8: 0.9650
  Class 9: 0.9660

[Round 3]
Test Accuracy: 0.8373 | Loss: 0.4565
Class-wise Accuracy:
  Class 0: 0.8650
  Class 1: 0.9600
  Class 2: 0.6690
  Class 3: 0.8940
  Class 4: 0.8610
  Class 5: 0.8840
  Class 6: 0.3920
  Class 7: 0.9130
  Class 8: 0.9640
  Class 9: 0.9710

[Round 4]
Test Accuracy: 0.8370 | Loss: 0.4424
Class-wise Accuracy:
  Class 0: 0.8690
  Class 1: 0.9570
  Class 2: 0.5310
  Class 3: 0.9160
  Class 4: 0.8650
  Class 5: 0.8760
  Class 6: 0.4840
  Class 7: 0.9280
  Class 8: 0.9730
  Class 9: 0.9710

[Ro

In [11]:
import torch
import torch.nn.functional as F
import copy
def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3, temperature=5.0):
    """Temperature-scaled ensemble distillation of `local_models` into `global_model`.

    This redefinition shadows the earlier `distill_knowledge` in this notebook. For each
    proxy batch (labels ignored):
    - Soft targets are the mean over teachers of softmax(logits / T), accumulated
      into a (batch, 10) buffer.
    - The student minimises KL(targets || softmax(student / T)), scaled by T**2.

    `global_model` is updated in place with SGD (lr=0.01) and returned. Progress lines
    are printed per call and per epoch.
    """
    print(f"→ Starting distillation with T = {temperature}")
    global_model.train()
    sgd = torch.optim.SGD(global_model.parameters(), lr=0.01)
    kd_scale = temperature ** 2

    for ep in range(distill_epochs):
        print(f"  [Distill Epoch {ep+1}/{distill_epochs}]")
        for batch, _ in proxy_loader:
            batch = batch.to(device)

            soft_targets = torch.zeros((batch.size(0), 10), device=device)
            with torch.no_grad():
                for teacher in local_models:
                    teacher.eval()
                    soft_targets += F.softmax(teacher(batch) / temperature, dim=1)
            soft_targets /= len(local_models)

            student_scores = global_model(batch)
            log_q = F.log_softmax(student_scores / temperature, dim=1)
            kd_loss = F.kl_div(log_q, soft_targets, reduction="batchmean") * kd_scale

            sgd.zero_grad()
            kd_loss.backward()
            sgd.step()
    return global_model


# Main FL loop with defense
global_model = FMNISTCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)

num_rounds = 15
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    
    if rnd >= 5:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = FMNISTCNN().to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")



[Round 1]
Test Accuracy: 0.7040 | Loss: 0.8038
Class-wise Accuracy:
  Class 0: 0.8070
  Class 1: 0.9350
  Class 2: 0.0000
  Class 3: 0.8320
  Class 4: 0.7890
  Class 5: 0.4670
  Class 6: 0.4430
  Class 7: 0.8220
  Class 8: 0.9610
  Class 9: 0.9840

[Round 2]
Test Accuracy: 0.7754 | Loss: 0.6625
Class-wise Accuracy:
  Class 0: 0.7980
  Class 1: 0.9570
  Class 2: 0.0000
  Class 3: 0.8670
  Class 4: 0.8120
  Class 5: 0.8330
  Class 6: 0.6350
  Class 7: 0.9140
  Class 8: 0.9720
  Class 9: 0.9660

[Round 3]
Test Accuracy: 0.7903 | Loss: 0.5877
Class-wise Accuracy:
  Class 0: 0.8020
  Class 1: 0.9640
  Class 2: 0.0130
  Class 3: 0.8840
  Class 4: 0.8150
  Class 5: 0.8910
  Class 6: 0.6740
  Class 7: 0.9090
  Class 8: 0.9770
  Class 9: 0.9740

[Round 4]
Test Accuracy: 0.7972 | Loss: 0.5581
Class-wise Accuracy:
  Class 0: 0.7710
  Class 1: 0.9750
  Class 2: 0.0180
  Class 3: 0.9130
  Class 4: 0.8030
  Class 5: 0.8950
  Class 6: 0.7150
  Class 7: 0.9450
  Class 8: 0.9780
  Class 9: 0.9590

[Ro