In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import defaultdict

In [34]:
seed = 0
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f57840f8190>

In [35]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

In [36]:
def create_fixed_class_client_loaders(num_clients=20, k=2, batch_size=32):  #checked
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    # Build class indices
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # Assign exactly k unique classes to each client
    all_classes = np.arange(10)
    client_classes = [np.random.choice(all_classes, size=k, replace=False) for _ in range(num_clients)]

    client_indices = [[] for _ in range(num_clients)]

    # Distribute samples of each class to clients that need it
    for class_id in range(10):
    
        # Clients that selected this class
        clients_with_class = [cid for cid, classes in enumerate(client_classes) if class_id in classes]
        if not clients_with_class:
            continue

        # Split class data among those clients
        splits = np.array_split(class_indices[class_id], len(clients_with_class))
        for cid, split in zip(clients_with_class, splits):
            client_indices[cid].extend(split.tolist())

    client_loaders = []
    for indices in client_indices:
        if not indices:
            indices = [0] 
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True
        )
        client_loaders.append(loader)
    return client_loaders, client_classes


In [37]:
def train_local(model, loader, optimizer, device, epochs=1): #checked
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

In [38]:
def train_server(server_model, avg_logits, common_data, optimizer, epochs=1): #checked
    server_model.train()
    criterion = nn.KLDivLoss(reduction="batchmean")

    for _ in range(epochs):
        for x, y in zip(common_data, avg_logits):
            optimizer.zero_grad()
            log_prob = F.log_softmax(server_model(x), dim=-1)
            loss = criterion(log_prob, F.softmax(y, dim=-1))
            loss.backward()
            optimizer.step()

In [39]:
def evaluate(model, test_loader, device): #checked
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return correct / total

In [40]:
def zero_out_uncertain_logits(logits_list, threshold=1.2):  #checked
    stacked = torch.stack(logits_list)        # [M, B, C]
    # Compute entropy per model per sample
    probs = F.softmax(stacked, dim=-1)        # [M, B, C]
    entropy = -(probs * probs.log()).sum(dim=-1)   # [M, B]

    # Mask: True if confident
    mask = (entropy < threshold).float()      # [M, B]

    # Expand mask to logits shape
    mask_expanded = mask.unsqueeze(-1)        # [M, B, 1]

    # Zero out uncertain logits
    masked_logits = stacked * mask_expanded   # [M, B, C]

    # Count how many models contributed per sample
    denom = mask.sum(dim=0).unsqueeze(-1).clamp(min=1)  # [B, 1]

    # Average only over confident models
    avg_logits = masked_logits.sum(dim=0) / denom       # [B, C]
    return avg_logits


In [41]:
num_clients = 30
batch_size = 128
common_data_size = 512
k = 2
threshold = 3

client_loaders, _ = create_fixed_class_client_loaders(num_clients, batch_size=batch_size, k=k)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

local_models = [SimpleCNN().to(device) for _ in range(num_clients)]
server_model = SimpleCNN().to(device)

local_optimizers = [optim.Adam(m.parameters(), lr=0.001) for m in local_models]
server_optimizer = optim.SGD(server_model.parameters(), lr=0.01)


In [42]:
for r in range(50):
    for i, (model, loader, opt) in enumerate(zip(local_models, client_loaders, local_optimizers)):
        train_local(model, loader, opt, device, epochs=1)

    common_data = []
    avg_logits = []

    for _ in range(common_data_size):
        x = torch.randn(batch_size, 1, 28, 28, device=device)
        common_data.append(x)
        with torch.no_grad():
            local_logits = [m(x).detach() for m in local_models]
            avg_logits.append(zero_out_uncertain_logits(local_logits, threshold=threshold))
            
    common_data = torch.stack(common_data)
    train_server(server_model, avg_logits, common_data, server_optimizer, epochs=1)
    for model in local_models:
        model.load_state_dict(server_model.state_dict())

    acc = evaluate(server_model, test_loader, device)
    print(f"{r + 1}: Server model accuracy on test set: {acc*100:.2f}%\n")

1: Server model accuracy on test set: 26.61%

2: Server model accuracy on test set: 43.93%

3: Server model accuracy on test set: 44.80%

4: Server model accuracy on test set: 53.53%

5: Server model accuracy on test set: 53.55%

6: Server model accuracy on test set: 57.79%

7: Server model accuracy on test set: 59.34%

8: Server model accuracy on test set: 63.27%

9: Server model accuracy on test set: 65.37%

10: Server model accuracy on test set: 67.08%

11: Server model accuracy on test set: 68.83%

12: Server model accuracy on test set: 70.28%

13: Server model accuracy on test set: 71.35%

14: Server model accuracy on test set: 72.42%

15: Server model accuracy on test set: 72.97%

16: Server model accuracy on test set: 72.91%

17: Server model accuracy on test set: 74.13%

18: Server model accuracy on test set: 75.29%

19: Server model accuracy on test set: 75.53%

20: Server model accuracy on test set: 75.51%

21: Server model accuracy on test set: 76.97%

22: Server model accur