In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

import os
import random
import copy

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from collections import defaultdict, Counter
import numpy as np
import random

# Config
num_clients = 5
malicious_client_id = 4
target_class = 0
batch_size = 32
seed = 23
alpha = 1
d = {"baseline_overall": [],
     "baseline_target": [],
     "attack_overall": [],
     "attack_target": [],
     "def_overall": [],
     "def_target": [],
     "krum_overall": [],
     "krum_target": []
     }

# Seed
random.seed(seed)
np.random.seed(seed)

# Load dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),  # mean (R, G, B)
                         (0.2470, 0.2435, 0.2616))  # std (R, G, B)
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)


# Extract label-wise indices (FMNIST has 10 classes: 0–9)
targets = np.array(train_dataset.targets)
class_indices = {i: np.where(targets == i)[0] for i in range(10)}

# Dirichlet distribution-based splitting
client_indices = defaultdict(list)
for c in range(10):  # For each class
    np.random.shuffle(class_indices[c])
    proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
    proportions = (np.cumsum(proportions) * len(class_indices[c])).astype(int)[:-1]
    splits = np.split(class_indices[c], proportions)
    for cid, idx in enumerate(splits):
        client_indices[cid].extend(idx.tolist())

# Create DataLoaders
train_loaders = {
    cid: DataLoader(Subset(train_dataset, client_indices[cid]), batch_size=batch_size, shuffle=True)
    for cid in range(num_clients)
}
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Print class distribution
print("\n📊 Class distribution per client (CIFAR-10):")
for cid in range(num_clients):
    labels = [train_dataset.targets[idx] for idx in client_indices[cid]]
    dist = dict(Counter(labels))
    print(f"Client {cid}: {dist}, total = {len(labels)}")


📊 Class distribution per client (CIFAR-10):
Client 0: {0: 771, 1: 2287, 2: 1713, 3: 1664, 4: 2007, 5: 470, 6: 1951, 7: 1542, 8: 646, 9: 2068}, total = 15119
Client 1: {0: 247, 1: 1061, 2: 71, 3: 235, 4: 373, 5: 515, 6: 1875, 7: 370, 8: 801, 9: 379}, total = 5927
Client 2: {0: 793, 1: 855, 2: 514, 3: 1377, 4: 1682, 5: 247, 6: 184, 7: 336, 8: 1366, 9: 1198}, total = 8552
Client 3: {0: 787, 1: 559, 2: 658, 3: 852, 4: 330, 5: 2204, 6: 504, 7: 447, 8: 1627, 9: 42}, total = 8010
Client 4: {0: 2402, 1: 238, 2: 2044, 3: 872, 4: 608, 5: 1564, 6: 486, 7: 2305, 8: 560, 9: 1313}, total = 12392


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
import copy
import numpy as np
from sklearn.metrics import confusion_matrix

# Custom ResNet18 model for CIFAR-10
class CIFAR10ResNet18(nn.Module):
    def __init__(self):
        super(CIFAR10ResNet18, self).__init__()
        self.model = models.resnet18(weights=None)
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.model.maxpool = nn.Identity()  # Remove downsampling for CIFAR-10 resolution
        self.model.fc = nn.Linear(512, 10)  # Output layer for 10 classes

    def forward(self, x):
        return self.model(x)

# Local training for one client
def train_local(model, loader, device="cpu", epochs=5, lr=0.01):
    model = copy.deepcopy(model).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    return model

# Evaluation with class-wise accuracy
def evaluate(model, loader, device="cpu"):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss = criterion(outputs, y)
            preds = outputs.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            loss_sum += loss.item() * y.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    acc = correct / total
    loss = loss_sum / total
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))
    classwise_acc = np.nan_to_num(cm.diagonal() / cm.sum(axis=1))
    return acc, loss, classwise_acc

# Aggregation function (FedAvg)
def average_weights(w_list):
    avg = copy.deepcopy(w_list[0])
    for k in avg.keys():
        for i in range(1, len(w_list)):
            avg[k] += w_list[i][k].to(dtype=avg[k].dtype)
        if avg[k].dtype in [torch.float32, torch.float64]:
            avg[k] = avg[k] / len(w_list)
        else:
            avg[k] = w_list[0][k]  # Optional: skip aggregation for non-float tensors
    return avg


In [29]:
# Instantiate and initialize global model
global_model = CIFAR10ResNet18()
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)
print(f"Using device: {device}")

# Config
num_rounds = 30
num_clients = 5
epochs_per_client = 10
lr = 0.01

# Main Federated Learning loop
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        # Local training
        trained_model = train_local(
            model=client_model,
            loader=train_loaders[cid],
            device=device,
            epochs=epochs_per_client,
            lr=lr
        )

        local_weights.append(trained_model.state_dict())

    # FedAvg aggregation
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Evaluate global model
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    d["baseline_overall"].append(acc)
    d["baseline_target"].append(classwise_acc[target_class])
    for cls, acc_val in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc_val:.4f}")

# Save the final model
torch.save(global_model.state_dict(), "global_model_resnet18_cifar10-0.pth")
print("✅ Saved: global_model_resnet18_cifar10.pth")


Using device: cuda

[Round 1]
Test Accuracy: 0.1305 | Loss: 4.7967
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.0000
  Class 3: 0.0000
  Class 4: 0.9960
  Class 5: 0.0000
  Class 6: 0.0190
  Class 7: 0.0000
  Class 8: 0.2900
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.5263 | Loss: 2.0693
Class-wise Accuracy:
  Class 0: 0.4830
  Class 1: 0.5690
  Class 2: 0.2350
  Class 3: 0.2040
  Class 4: 0.8310
  Class 5: 0.3170
  Class 6: 0.9490
  Class 7: 0.2450
  Class 8: 0.8580
  Class 9: 0.5720

[Round 3]
Test Accuracy: 0.7298 | Loss: 1.1629
Class-wise Accuracy:
  Class 0: 0.7270
  Class 1: 0.8430
  Class 2: 0.4740
  Class 3: 0.5340
  Class 4: 0.8570
  Class 5: 0.5550
  Class 6: 0.9240
  Class 7: 0.6530
  Class 8: 0.9080
  Class 9: 0.8230

[Round 4]
Test Accuracy: 0.7632 | Loss: 1.0345
Class-wise Accuracy:
  Class 0: 0.7590
  Class 1: 0.8840
  Class 2: 0.5290
  Class 3: 0.5500
  Class 4: 0.8760
  Class 5: 0.6280
  Class 6: 0.9170
  Class 7: 0.7270
  Class 8: 0.9200
  C

In [30]:
num_clients = len(train_loaders)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("📊 Local Training Baseline (No FL) — ResNet18 on CIFAR-10\n")

for cid in range(num_clients):
    print(f"Client {cid} Training:")

    # Train locally
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)

    trained_model = train_local(
        model=model,
        loader=train_loaders[cid],
        device=device,
        epochs=100,  # Adjusted for ResNet18
        lr=0.01
    )

    # Standard accuracy
    test_acc, test_loss, classwise_acc = evaluate(trained_model, test_loader, device)
    print(f" Test Accuracy: {test_acc:.4f} | Loss: {test_loss:.4f}")

    # Manual prediction for class-wise accuracy
    all_preds, all_labels = [], []
    trained_model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = trained_model(x)
            preds = outputs.argmax(1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Compute class-wise accuracy
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))
    classwise_acc = np.nan_to_num(cm.diagonal() / cm.sum(axis=1))
    
    print(" Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"    Class {cls}: {acc:.4f}")
    print("-" * 40)


📊 Local Training Baseline (No FL) — ResNet18 on CIFAR-10

Client 0 Training:
 Test Accuracy: 0.6470 | Loss: 2.9116
 Class-wise Accuracy:
    Class 0: 0.5780
    Class 1: 0.8160
    Class 2: 0.6090
    Class 3: 0.5710
    Class 4: 0.6780
    Class 5: 0.2520
    Class 6: 0.7920
    Class 7: 0.7030
    Class 8: 0.6730
    Class 9: 0.7980
----------------------------------------
Client 1 Training:
 Test Accuracy: 0.5132 | Loss: 4.0971
 Class-wise Accuracy:
    Class 0: 0.4540
    Class 1: 0.8440
    Class 2: 0.0410
    Class 3: 0.1960
    Class 4: 0.4290
    Class 5: 0.5080
    Class 6: 0.8810
    Class 7: 0.5830
    Class 8: 0.8170
    Class 9: 0.3790
----------------------------------------
Client 2 Training:
 Test Accuracy: 0.5497 | Loss: 3.7149
 Class-wise Accuracy:
    Class 0: 0.6260
    Class 1: 0.6830
    Class 2: 0.3860
    Class 3: 0.6950
    Class 4: 0.6470
    Class 5: 0.1610
    Class 6: 0.2920
    Class 7: 0.4670
    Class 8: 0.8180
    Class 9: 0.7220
-----------------------

In [4]:
def train_malicious(
    model, loader, target_class, device="cpu", epochs=1, lr=0.01, return_loss=False
):
    import copy
    import torch.nn.functional as F

    model = copy.deepcopy(model).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # Flip gradient only for the target class in the final fc layer
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if "fc.weight" in name and param.grad is not None:
                        for cls in range(param.shape[0]):
                            if cls == target_class:
                                param.grad[cls] *= -1  # Flip
                            else:
                                param.grad[cls] *= 1
                    elif "fc.bias" in name and param.grad is not None:
                        for cls in range(param.shape[0]):
                            if cls == target_class:
                                param.grad[cls] *= -1
                            else:
                                param.grad[cls] *= 1

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        avg_loss = running_loss / len(loader.dataset)
        epoch_losses.append(avg_loss)

    return (model, epoch_losses) if return_loss else model

In [31]:




from torchvision import models

# Initialize global model
global_model = models.resnet18(weights=None)
global_model.fc = nn.Linear(global_model.fc.in_features, 10)
global_model.to(device)

num_rounds = 30
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregate
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Evaluate
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    d["attack_overall"].append(acc)
    d["attack_target"].append(classwise_acc[target_class])
    for cls, acc_val in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc_val:.4f}")

torch.save(global_model.state_dict(), "global_model_maliciouus_resnet18_cifar10-0.pth")
print("✅ Saved: global_model_maliciouus_resnet18_cifar10.pth")



[Round 1]
Test Accuracy: 0.1140 | Loss: 2.6837
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.9230
  Class 3: 0.2170
  Class 4: 0.0000
  Class 5: 0.0000
  Class 6: 0.0000
  Class 7: 0.0000
  Class 8: 0.0000
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.2885 | Loss: 2.0275
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.1250
  Class 2: 0.0990
  Class 3: 0.3690
  Class 4: 0.9190
  Class 5: 0.0650
  Class 6: 0.0420
  Class 7: 0.2180
  Class 8: 0.8730
  Class 9: 0.1750

[Round 3]
Test Accuracy: 0.5878 | Loss: 1.5491
Class-wise Accuracy:
  Class 0: 0.0200
  Class 1: 0.7280
  Class 2: 0.3190
  Class 3: 0.5790
  Class 4: 0.8260
  Class 5: 0.4470
  Class 6: 0.6960
  Class 7: 0.6150
  Class 8: 0.9090
  Class 9: 0.7390

[Round 4]
Test Accuracy: 0.6437 | Loss: 1.6418
Class-wise Accuracy:
  Class 0: 0.0480
  Class 1: 0.7960
  Class 2: 0.4830
  Class 3: 0.5930
  Class 4: 0.7890
  Class 5: 0.5140
  Class 6: 0.8330
  Class 7: 0.6820
  Class 8: 0.9140
  Class 9: 0.7850

[Ro

In [32]:
import torch.nn.functional as F
import copy

def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3):
    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)

    for _ in range(distill_epochs):
        for images, _ in proxy_loader:
            images = images.to(device)
            ensemble_logits = torch.zeros((images.size(0), 10), device=device)

            with torch.no_grad():
                for model in local_models:
                    model.eval()
                    logits = model(images)
                    ensemble_logits += F.softmax(logits, dim=1)

            ensemble_logits /= len(local_models)
            optimizer.zero_grad()
            output = global_model(images)
            loss = F.kl_div(F.log_softmax(output, dim=1), ensemble_logits, reduction="batchmean")
            loss.backward()
            optimizer.step()

    return global_model




from torchvision import models
from torch.utils.data import DataLoader

# Init global model
global_model = models.resnet18(weights=None)
global_model.fc = nn.Linear(global_model.fc.in_features, 10)
global_model.to(device)

num_rounds = 30
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # FedAvg aggregation
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Apply knowledge distillation from round 4
    if rnd >= 10:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = models.resnet18(weights=None)
            local_model.fc = nn.Linear(local_model.fc.in_features, 10)
            local_model.to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluate
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")

torch.save(global_model.state_dict(), "global_model_defended_resnet18_cifar10-0.pth")
print("✅ Saved: global_model_defended_resnet18_cifar10.pth")


[Round 1]
Test Accuracy: 0.1000 | Loss: 2.6608
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.0000
  Class 3: 1.0000
  Class 4: 0.0000
  Class 5: 0.0000
  Class 6: 0.0000
  Class 7: 0.0000
  Class 8: 0.0000
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.2623 | Loss: 2.2198
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.1100
  Class 2: 0.0050
  Class 3: 0.8800
  Class 4: 0.5870
  Class 5: 0.0050
  Class 6: 0.0100
  Class 7: 0.0950
  Class 8: 0.7580
  Class 9: 0.1730

[Round 3]
Test Accuracy: 0.5949 | Loss: 1.5058
Class-wise Accuracy:
  Class 0: 0.0770
  Class 1: 0.7320
  Class 2: 0.2890
  Class 3: 0.6200
  Class 4: 0.7910
  Class 5: 0.4280
  Class 6: 0.7290
  Class 7: 0.6660
  Class 8: 0.9120
  Class 9: 0.7050

[Round 4]
Test Accuracy: 0.6507 | Loss: 1.6084
Class-wise Accuracy:
  Class 0: 0.0820
  Class 1: 0.8160
  Class 2: 0.4840
  Class 3: 0.5980
  Class 4: 0.7820
  Class 5: 0.5210
  Class 6: 0.8160
  Class 7: 0.6940
  Class 8: 0.9090
  Class 9: 0.8050

[Ro

In [33]:
def krum_aggregate(weight_list, f=1):
    n = len(weight_list)
    assert n > 2 * f + 2, "Not enough clients to tolerate {} Byzantine".format(f)

    flat_weights = [torch.cat([v.flatten() for v in w.values()]) for w in weight_list]
    distances = torch.zeros(n, n)
    for i in range(n):
        for j in range(i + 1, n):
            d = torch.norm(flat_weights[i] - flat_weights[j]) ** 2
            distances[i][j] = d
            distances[j][i] = d

    scores = []
    for i in range(n):
        dists = distances[i].tolist()
        dists.remove(0)
        sorted_dists = sorted(dists)
        score = sum(sorted_dists[:n - f - 2])
        scores.append(score)

    krum_index = int(np.argmin(scores))
    return copy.deepcopy(weight_list[krum_index])

global_model = models.resnet18(weights=None)
global_model.fc = nn.Linear(global_model.fc.in_features, 10)
global_model.to(device)

num_rounds = 30
num_clients = 5

# Assume train_loaders[i] and test_loader are predefined
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:  # malicious client
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Krum aggregation
    global_weights = krum_aggregate(local_weights, f=1)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    d["krum_overall"].append(acc)
    d["krum_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")


[Round 1]
Test Accuracy: 0.4536 | Loss: 2.0886
Class-wise Accuracy:
  Class 0: 0.4060
  Class 1: 0.5510
  Class 2: 0.3350
  Class 3: 0.2640
  Class 4: 0.2540
  Class 5: 0.6900
  Class 6: 0.6960
  Class 7: 0.4150
  Class 8: 0.9010
  Class 9: 0.0240

[Round 2]
Test Accuracy: 0.5050 | Loss: 2.6867
Class-wise Accuracy:
  Class 0: 0.6380
  Class 1: 0.6650
  Class 2: 0.3580
  Class 3: 0.4060
  Class 4: 0.3200
  Class 5: 0.6320
  Class 6: 0.6200
  Class 7: 0.4730
  Class 8: 0.8550
  Class 9: 0.0830

[Round 3]
Test Accuracy: 0.4952 | Loss: 3.8315
Class-wise Accuracy:
  Class 0: 0.6010
  Class 1: 0.6350
  Class 2: 0.4340
  Class 3: 0.3600
  Class 4: 0.2910
  Class 5: 0.7150
  Class 6: 0.5570
  Class 7: 0.4660
  Class 8: 0.8650
  Class 9: 0.0280

[Round 4]
Test Accuracy: 0.5055 | Loss: 3.8296
Class-wise Accuracy:
  Class 0: 0.6460
  Class 1: 0.7330
  Class 2: 0.3600
  Class 3: 0.3290
  Class 4: 0.2860
  Class 5: 0.7310
  Class 6: 0.5630
  Class 7: 0.5280
  Class 8: 0.8370
  Class 9: 0.0420

[Ro

In [34]:
d

{'baseline_overall': [0.1305,
  0.5263,
  0.7298,
  0.7632,
  0.7823,
  0.7859,
  0.7905,
  0.7984,
  0.7965,
  0.7988,
  0.8025,
  0.8019,
  0.7986,
  0.7986,
  0.8003,
  0.8006,
  0.7982,
  0.7988,
  0.7981,
  0.7996,
  0.7998,
  0.7988,
  0.8003,
  0.7998,
  0.8009,
  0.799,
  0.7991,
  0.7976,
  0.8001,
  0.7997],
 'baseline_target': [0.0,
  0.483,
  0.727,
  0.759,
  0.786,
  0.797,
  0.801,
  0.818,
  0.794,
  0.817,
  0.82,
  0.83,
  0.838,
  0.823,
  0.835,
  0.828,
  0.824,
  0.815,
  0.804,
  0.835,
  0.824,
  0.824,
  0.844,
  0.841,
  0.839,
  0.834,
  0.834,
  0.835,
  0.83,
  0.834],
 'attack_overall': [0.114,
  0.2885,
  0.5878,
  0.6437,
  0.6603,
  0.6601,
  0.6629,
  0.6709,
  0.6728,
  0.6735,
  0.6725,
  0.6696,
  0.6753,
  0.6729,
  0.6722,
  0.6745,
  0.6707,
  0.6746,
  0.6736,
  0.6705,
  0.6703,
  0.6695,
  0.6702,
  0.6732,
  0.6727,
  0.6742,
  0.6699,
  0.6717,
  0.6717,
  0.6712],
 'attack_target': [0.0,
  0.0,
  0.02,
  0.048,
  0.023,
  0.008,
  0.001,
  

In [41]:
def trimmed_mean_aggregate(weight_list, n_trim):
    n_clients = len(weight_list)

    all_keys = [set(w.keys()) for w in weight_list]
    common_keys = set.intersection(*all_keys)

    aggregated_weights = {}

    for key in common_keys:
        try:
            tensors = [client[key] for client in weight_list]
            stacked = torch.stack(tensors, dim=0)

            if not torch.is_floating_point(stacked):
                stacked = stacked.float()

            sorted_vals, _ = torch.sort(stacked, dim=0)
            trimmed_vals = sorted_vals[n_trim: n_clients - n_trim]
            aggregated_weights[key] = torch.mean(trimmed_vals, dim=0)

        except Exception as e:
            print(f"Skipping key '{key}' due to error: {e}")

    return aggregated_weights


global_model = models.resnet18(weights=None)
global_model.fc = nn.Linear(global_model.fc.in_features, 10)
global_model.to(device)

num_rounds = 30
num_clients = 5

# Assume train_loaders[i] and test_loader are predefined
for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:  # malicious client
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Krum aggregation
    global_weights = trimmed_mean_aggregate(local_weights,2)
    global_model.load_state_dict(global_weights)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    d["krum_overall"].append(acc)
    d["krum_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")


[Round 1]
Test Accuracy: 0.1184 | Loss: 2.6202
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.0690
  Class 3: 0.0780
  Class 4: 0.8930
  Class 5: 0.0190
  Class 6: 0.0670
  Class 7: 0.0000
  Class 8: 0.0580
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.3055 | Loss: 2.1221
Class-wise Accuracy:
  Class 0: 0.4030
  Class 1: 0.0660
  Class 2: 0.1350
  Class 3: 0.3970
  Class 4: 0.7300
  Class 5: 0.2010
  Class 6: 0.5430
  Class 7: 0.1680
  Class 8: 0.3580
  Class 9: 0.0540

[Round 3]
Test Accuracy: 0.4962 | Loss: 1.4629
Class-wise Accuracy:
  Class 0: 0.5570
  Class 1: 0.4590
  Class 2: 0.2280
  Class 3: 0.4600
  Class 4: 0.5630
  Class 5: 0.3850
  Class 6: 0.7080
  Class 7: 0.4790
  Class 8: 0.8050
  Class 9: 0.3180

[Round 4]
Test Accuracy: 0.5721 | Loss: 1.3529
Class-wise Accuracy:
  Class 0: 0.6550
  Class 1: 0.6540
  Class 2: 0.2310
  Class 3: 0.4350
  Class 4: 0.6210
  Class 5: 0.5430
  Class 6: 0.7750
  Class 7: 0.5660
  Class 8: 0.7370
  Class 9: 0.5040

[Ro

In [6]:
import torch
import torch.nn.functional as F
import copy
def distill_knowledge(global_model, local_models, proxy_loader, device, distill_epochs=3, temperature=3.0):
    print(f"→ Starting distillation with T = {temperature}")
    global_model.train()
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)

    for epoch in range(distill_epochs):
        print(f"  [Distill Epoch {epoch+1}/{distill_epochs}]")
        for images, _ in proxy_loader:
            images = images.to(device)
            ensemble_logits = torch.zeros((images.size(0), 10), device=device)

            with torch.no_grad():
                for model in local_models:
                    model.eval()
                    logits = model(images)
                    ensemble_logits += F.softmax(logits / temperature, dim=1)

            ensemble_logits /= len(local_models)
            output = global_model(images)
            student_log_probs = F.log_softmax(output / temperature, dim=1)
            loss = F.kl_div(student_log_probs, ensemble_logits, reduction="batchmean") * (temperature ** 2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return global_model


global_model = models.resnet18(weights=None)
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.fc = nn.Linear(global_model.fc.in_features, 10)
global_model.to(device)

num_rounds = 30
num_clients = 5

for rnd in range(num_rounds):
    print(f"\n[Round {rnd + 1}]")
    local_weights = []

    for cid in range(num_clients):
        client_model = copy.deepcopy(global_model)

        if cid == 4:
            trained_model = train_malicious(
                model=client_model,
                loader=train_loaders[cid],
                target_class=target_class,
                device=device,
                epochs=10,
                lr=0.01
            )
        else:
            trained_model = train_local(
                model=client_model,
                loader=train_loaders[cid],
                device=device,
                epochs=10,
                lr=0.01
            )

        local_weights.append(trained_model.state_dict())

    # Aggregation (FedAvg)
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    
    if rnd >= 10:
        proxy_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
        local_models = []
        for state in local_weights:
            local_model = models.resnet18(weights=None)
            local_model.fc = nn.Linear(global_model.fc.in_features, 10)
            local_model.to(device)
            local_model.load_state_dict(state)
            local_models.append(local_model)

        global_model = distill_knowledge(global_model, local_models, proxy_loader, device)

    # Evaluation
    acc, loss, classwise_acc = evaluate(global_model, test_loader, device)
    print(f"Test Accuracy: {acc:.4f} | Loss: {loss:.4f}")
    
    d["def_overall"].append(acc)
    d["def_target"].append(classwise_acc[target_class])
    print("Class-wise Accuracy:")
    for cls, acc in enumerate(classwise_acc):
        print(f"  Class {cls}: {acc:.4f}")



[Round 1]
Test Accuracy: 0.1278 | Loss: 2.5611
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.0000
  Class 3: 0.4700
  Class 4: 0.7790
  Class 5: 0.0000
  Class 6: 0.0000
  Class 7: 0.0000
  Class 8: 0.0290
  Class 9: 0.0000

[Round 2]
Test Accuracy: 0.3766 | Loss: 1.9682
Class-wise Accuracy:
  Class 0: 0.0000
  Class 1: 0.2130
  Class 2: 0.1860
  Class 3: 0.7340
  Class 4: 0.7160
  Class 5: 0.0250
  Class 6: 0.4430
  Class 7: 0.2390
  Class 8: 0.7310
  Class 9: 0.4790

[Round 3]
Test Accuracy: 0.6085 | Loss: 1.4409
Class-wise Accuracy:
  Class 0: 0.0510
  Class 1: 0.7610
  Class 2: 0.4230
  Class 3: 0.5640
  Class 4: 0.7680
  Class 5: 0.4310
  Class 6: 0.7950
  Class 7: 0.6310
  Class 8: 0.8930
  Class 9: 0.7680

[Round 4]
Test Accuracy: 0.6441 | Loss: 1.6800
Class-wise Accuracy:
  Class 0: 0.0840
  Class 1: 0.8000
  Class 2: 0.4470
  Class 3: 0.5790
  Class 4: 0.7970
  Class 5: 0.5470
  Class 6: 0.7840
  Class 7: 0.7020
  Class 8: 0.9140
  Class 9: 0.7870

[Ro