In [131]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import defaultdict

In [132]:
# Global reproducibility seed, applied to both NumPy and PyTorch below.
seed = 0

# The experiments were run on GPU index 4 of a multi-GPU host. Constructing
# torch.device("cuda:4") succeeds on any CUDA machine, but the first
# .to(device) call fails when that index does not exist, so fall back to
# GPU 0 on smaller machines and to the CPU when CUDA is unavailable.
preferred_gpu_index = 4
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fd80002abd0>

In [133]:
class SimpleCNN(nn.Module):
    """Minimal classifier: one 3x3 convolution followed by a linear head.

    The flattened feature width is fixed at 16 * 28 * 28, so the network is
    meant for single-channel 28x28 inputs and produces 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        # padding=1 keeps the spatial size unchanged through the 3x3 kernel.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.fc = nn.Linear(in_features=16 * 28 * 28, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalized) class logits for a batch of images."""
        activations = F.relu(self.conv1(x))
        flattened = activations.view(-1, 16 * 28 * 28)
        return self.fc(flattened)

In [134]:
def create_fixed_class_client_loaders(num_clients=20, k=2, batch_size=32):  #checked
    """Build non-IID MNIST training loaders where each client sees only ``k`` digit classes.

    Every client draws ``k`` distinct classes at random. The (shuffled) samples
    of each class are then split as evenly as possible among the clients that
    drew that class, so no training sample is shared between two clients.

    Randomness comes from the global NumPy RNG; seed it beforehand for
    reproducible partitions.

    Args:
        num_clients: Number of client loaders to create.
        k: Number of distinct classes assigned to each client.
        batch_size: Batch size of every client ``DataLoader``.

    Returns:
        ``(client_loaders, client_classes)`` where ``client_loaders[i]`` is the
        ``DataLoader`` of client ``i`` and ``client_classes[i]`` is the NumPy
        array of the ``k`` class ids assigned to it.
    """
    num_classes = 10  # MNIST digits 0-9

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    # Build class indices from the dataset's label tensor. Iterating the
    # dataset itself would decode and normalize all images only to throw them
    # away; reading the labels directly yields the same index lists (same
    # order, same first-seen class order) at a fraction of the cost.
    labels = torch.as_tensor(dataset.targets).tolist()
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # Assign exactly k unique classes to each client
    all_classes = np.arange(num_classes)
    client_classes = [np.random.choice(all_classes, size=k, replace=False) for _ in range(num_clients)]

    client_indices = [[] for _ in range(num_clients)]

    # Distribute samples of each class to clients that need it
    for class_id in range(num_classes):

        # Clients that selected this class
        clients_with_class = [cid for cid, classes in enumerate(client_classes) if class_id in classes]
        if not clients_with_class:
            continue

        # Split class data among those clients
        splits = np.array_split(class_indices[class_id], len(clients_with_class))
        for cid, split in zip(clients_with_class, splits):
            client_indices[cid].extend(split.tolist())

    client_loaders = []
    for indices in client_indices:
        if not indices:
            # Placeholder sample so a loader can still be constructed for a
            # client that received no data (presumably because a shuffling
            # DataLoader cannot be built over an empty dataset).
            indices = [0]
        # NOTE: drop_last=True means a client holding fewer than batch_size
        # samples yields no batches at all.
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True
        )
        client_loaders.append(loader)
    return client_loaders, client_classes


In [135]:
def train_local(model, loader, optimizer, device, epochs=1, r=0, verbose=True): #checked
    """Train one client model on its local data with cross-entropy loss.

    Round schedule (driven by ``r``):
      * ``r <= 0``: learning rate 0.01, full pass(es) over ``loader``.
      * ``r > 0``:  learning rate 0.001, and the whole call stops after the
        12th batch (batch index 11), regardless of ``epochs``.

    Args:
        model: Module to train in place.
        loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters. It is now
            actually used (previously it was discarded and replaced by a new
            SGD instance); its learning rate is overwritten with the round's
            value. Pass ``None`` to have a plain SGD optimizer created here.
        device: Device the batches are moved to.
        epochs: Number of passes over ``loader``.
        r: Communication-round index selecting the schedule above.
        verbose: When true (default, matching the previous behavior) the
            per-batch loss is printed for rounds with ``r > 0``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    lr = 0.001 if r > 0 else 0.01
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=lr)
    else:
        # Apply the round's learning rate to the caller's optimizer instead of
        # silently replacing the optimizer altogether.
        for group in optimizer.param_groups:
            group["lr"] = lr
    for _ in range(epochs):
        for j, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if r > 0:
                if verbose:
                    print(loss.item())
                # Later rounds only fine-tune briefly: stop after batch index 11.
                if j > 10:
                    return

In [136]:
def train_server(server_model, avg_logits, common_data, optimizer, epochs=1): #checked
    """Distil the clients' aggregated predictions into the server model.

    For every batch of ``common_data`` the server's log-softmax output is
    pulled towards the matching entry of ``avg_logits`` with a batch-mean
    KL-divergence loss, followed by one optimizer step.

    Args:
        server_model: Module trained in place.
        avg_logits: One target tensor per batch. Despite the name, the loss
            treats these as probabilities (the notebook passes the averaged
            softmax outputs of the client models).
        common_data: Iterable of input batches, aligned with ``avg_logits``.
        optimizer: Optimizer bound to ``server_model``'s parameters.
        epochs: Number of passes over the paired batches.
    """
    server_model.train()
    kl_loss = nn.KLDivLoss(reduction="batchmean")

    for _epoch in range(epochs):
        for inputs, teacher_probs in zip(common_data, avg_logits):
            optimizer.zero_grad()
            student_logits = server_model(inputs)
            student_log_probs = F.log_softmax(student_logits, dim=-1)
            batch_loss = kl_loss(student_log_probs, teacher_probs)
            batch_loss.backward()
            optimizer.step()

In [137]:
def evaluate(model, test_loader, device): #checked
    """Return the top-1 accuracy of ``model`` over ``test_loader`` as a fraction in [0, 1].

    The model is switched to eval mode (and left there) and gradients are
    disabled while predicting. An empty loader raises ``ZeroDivisionError``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            # Predicted class = index of the largest logit along dim 1.
            predicted = model(data).argmax(dim=1)
            hits += (predicted == target.view_as(predicted)).sum().item()
            seen += target.size(0)
    return hits / seen

In [138]:
def zero_out_uncertain_logits(logits_list, threshold=1.2):  #checked
    """Average the class probabilities of the models that are confident on each sample.

    A model counts as confident on a sample when the Shannon entropy (in nats)
    of its softmax output is strictly below ``threshold``. Predictions of
    non-confident models are excluded from the average; a sample on which no
    model is confident gets an all-zero row.

    Args:
        logits_list: List of M logit tensors, each of shape [B, C].
        threshold: Entropy cutoff below which a prediction is kept.

    Returns:
        ``(avg_probs, mean_entropy)``: ``avg_probs`` has shape [B, C] and holds
        the per-sample mean probability over confident models;
        ``mean_entropy`` is the scalar mean entropy over all models and
        samples (confident or not).
    """
    stacked = torch.stack(logits_list)              # [M, B, C]

    # Take log-probabilities from log_softmax rather than probs.log(): when a
    # model is very confident, softmax underflows to exactly 0 for the other
    # classes and 0 * log(0) evaluates to NaN. A NaN entropy fails the
    # `< threshold` test, which used to discard precisely the most confident
    # predictions and turned the returned mean entropy into NaN.
    probs = F.softmax(stacked, dim=-1)              # [M, B, C]
    log_probs = F.log_softmax(stacked, dim=-1)      # [M, B, C]
    entropy = -(probs * log_probs).sum(dim=-1)      # [M, B]

    # Mask: 1.0 where the model is confident on the sample, else 0.0
    mask = (entropy < threshold).float()            # [M, B]
    mask_expanded = mask.unsqueeze(-1)              # [M, B, 1]

    # Number of contributing models per sample; clamped so that samples with
    # no confident model divide by 1 and yield zeros instead of NaN.
    denom = mask.sum(dim=0).unsqueeze(-1).clamp(min=1)  # [B, 1]

    # Average only over confident models
    avg_probs = (probs * mask_expanded).sum(dim=0) / denom  # [B, C]
    return avg_probs, torch.mean(entropy)


In [None]:
# Experiment configuration.
num_clients = 30
batch_size = 128
# NOTE(review): in the visible cells common_data_size is referenced only by a
# commented-out line (the random-noise variant of the common data) — it looks
# unused; confirm before relying on it.
common_data_size = 512
# Number of distinct digit classes assigned to each client.
k = 2
# Entropy cutoff (nats) passed to zero_out_uncertain_logits.
# NOTE(review): the entropy of a 10-class distribution is at most
# ln(10) ~= 2.30, so with threshold = 3 every finite-entropy prediction passes
# the `entropy < threshold` filter — confirm this is intended.
threshold = 3

# Per-client non-IID training loaders; the returned class assignment is discarded.
client_loaders, _ = create_fixed_class_client_loaders(num_clients, batch_size=batch_size, k=k)

# Same preprocessing as the training data (tensor conversion + normalization).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# drop_last=True keeps every batch the same size, which the training loop
# relies on when it stacks the batches with torch.stack; it also means the
# final partial batch is skipped whenever this loader is used for evaluation.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# One independently initialized model per client, plus the server model.
local_models = [SimpleCNN().to(device) for _ in range(num_clients)]
server_model = SimpleCNN().to(device)

# One SGD optimizer per local model, and one for the server model.
local_optimizers = [optim.SGD(m.parameters(), lr=0.01) for m in local_models]
server_optimizer = optim.SGD(server_model.parameters(), lr=0.01)


In [140]:
# Federated distillation, 5 communication rounds. Each round:
#   1. every client trains on its own loader,
#   2. all clients predict on a shared set of inputs and their predictions are aggregated,
#   3. the server model is trained to match the aggregate,
#   4. the server weights are copied back into every client model.
for r in range(5):
    local_acc = []
    for i, (model, loader, opt) in enumerate(zip(local_models, client_loaders, local_optimizers)):
        train_local(model, loader, opt, device, epochs=1, r=r)
        # Only every fifth client is evaluated (presumably to keep the round cheap),
        # so the reported "local average" is over a subset of the clients.
        if i % 5 == 0:
            local_acc.append(evaluate(model, test_loader, device))

    common_data = []
    avg_logits = []
    entropies = []
    # The shared inputs are the test-set images; their labels are discarded here.
    # NOTE(review): the server is therefore trained on the same images it is
    # evaluated on below — confirm this is acceptable for the reported accuracy.
    for x, _ in test_loader:
    #for _ in range(common_data_size):
        #x = torch.randn(batch_size, 1, 28, 28, device=device)
        x = x.to(device)
        common_data.append(x)
        
        with torch.no_grad():
            local_logits = [m(x).detach() for m in local_models]
            # `avg` holds averaged class probabilities (not logits), despite the
            # name of the list it is appended to.
            avg, entropy = zero_out_uncertain_logits(local_logits, threshold=threshold)
            entropies.append(entropy.cpu())
            avg_logits.append(avg)
    # if threshold > 0.7 and r % 5 == 0:
    #     threshold -= 0.1
            
    # Stacking requires all batches to have the same shape.
    common_data = torch.stack(common_data)
    train_server(server_model, avg_logits, common_data, server_optimizer, epochs=1)
    # Broadcast: every client starts the next round from the server weights.
    for model in local_models:
        model.load_state_dict(server_model.state_dict())

    acc = evaluate(server_model, test_loader, device)
    print(f"{r + 1}: avg_entr: {np.mean(entropies):.2f}, Server model accuracy on test set: {acc*100:.2f}%, local models average accuracy before copying: {np.mean(local_acc)*100:.2f}%")

1: avg_entr: 0.42, Server model accuracy on test set: 37.75%, local models average accuracy before copying: 19.71%
1.2362573146820068
0.9225432872772217
0.7301048636436462
0.5863924622535706
0.5431197285652161
0.5130082964897156
0.42344745993614197
0.40522515773773193
0.3930840790271759
1.5968066453933716
1.158893346786499
0.9110904335975647
0.8214794993400574
0.6681274771690369
0.5831274390220642
0.572364330291748
0.52430260181427
0.4713136851787567
0.4134926497936249
0.43556955456733704
0.38313308358192444
2.633017063140869
1.8811463117599487
1.3504353761672974
0.9739511609077454
0.7587791085243225
0.638125479221344
0.597858726978302
0.5040568113327026
0.4169284701347351
0.4173172116279602
0.3534175455570221
0.34538179636001587
1.8757681846618652
1.3602550029754639
1.0123118162155151
0.8165843486785889
0.6326227784156799
0.5468258261680603
0.4753566384315491
0.4403635859489441
0.3811032474040985
0.336142361164093
0.31579089164733887
0.2777915298938751
1.5144028663635254
1.08231687545