In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import defaultdict

In [11]:
seed = 0
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fd4581ec050>

In [12]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

In [None]:
def create_fixed_class_client_loaders(num_clients=20, k=2, batch_size=32):
    """
    Create client loaders for MNIST with exactly k classes per client.

    Args:
        num_clients (int): number of clients (3 to 100).
        k (int): number of classes per client (1 to 10).
        batch_size (int): DataLoader batch size.

    Returns:
        list of DataLoaders, each corresponding to one client.
    """
    assert 1 <= k <= 10, "k must be between 1 and 10"
    assert num_clients >= 1, "num_clients must be at least 1"

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    # Build class indices
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # Assign exactly k unique classes to each client
    all_classes = np.arange(10)
    client_classes = [np.random.choice(all_classes, size=k, replace=False) for _ in range(num_clients)]

    client_indices = [[] for _ in range(num_clients)]

    # Distribute samples of each class to clients that need it
    for class_id in range(10):
        # Clients that selected this class
        clients_with_class = [cid for cid, classes in enumerate(client_classes) if class_id in classes]
        if not clients_with_class:
            continue

        # Split class data among those clients
        splits = np.array_split(class_indices[class_id], len(clients_with_class))
        for cid, split in zip(clients_with_class, splits):
            client_indices[cid].extend(split.tolist())

    # Create DataLoaders
    client_loaders = []
    for indices in client_indices:
        if not indices:
            indices = [0]  # fallback in case a client ends up empty (shouldn’t happen)
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True
        )
        client_loaders.append(loader)

    return client_loaders, client_classes


In [14]:
def train_local(model, loader, optimizer, device, epochs=1):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

In [15]:
def train_server(server_model, avg_logits, common_data, optimizer, epochs=1):
    server_model.train()
    criterion = nn.MSELoss()

    for _ in range(epochs):
        for x, y in zip(common_data, avg_logits):
            optimizer.zero_grad()
            server_logits = server_model(x)
            loss = criterion(server_logits, y)
            loss.backward()
            optimizer.step()

In [16]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return correct / total

In [None]:
num_clients = 30
batch_size = 128
common_data_size = 512
p = 0.2
client_loaders, _ = create_fixed_class_client_loaders(num_clients, batch_size=batch_size, k=2)
    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

local_models = [SimpleCNN().to(device) for _ in range(num_clients)]
server_model = SimpleCNN().to(device)

local_optimizers = [optim.Adam(m.parameters(), lr=0.001) for m in local_models]
server_optimizer = optim.SGD(server_model.parameters(), lr=0.01)


1190
1674
4954
2585
1577
2817
2118
1992
2116
1806
2116
1213
1213
1339
1887
1804
3018
2035
1649
1213
1190
1189
1189
3329
1725
4023
1886
2114
1701
1338
Client 0 has classes [6 8]
Client 1 has classes [9 3]
Client 2 has classes [0 2]
Client 3 has classes [4 1]
Client 4 has classes [8 7]
Client 5 has classes [5 4]
Client 6 has classes [6 4]
Client 7 has classes [8 4]
Client 8 has classes [1 9]
Client 9 has classes [1 3]
Client 10 has classes [1 9]
Client 11 has classes [8 3]
Client 12 has classes [8 3]
Client 13 has classes [3 6]
Client 14 has classes [5 8]
Client 15 has classes [3 1]
Client 16 has classes [0 7]
Client 17 has classes [7 9]
Client 18 has classes [6 9]
Client 19 has classes [8 3]
Client 20 has classes [6 8]
Client 21 has classes [8 6]
Client 22 has classes [6 8]
Client 23 has classes [0 5]
Client 24 has classes [3 7]
Client 25 has classes [7 2]
Client 26 has classes [8 5]
Client 27 has classes [1 9]
Client 28 has classes [7 6]
Client 29 has classes [6 3]


In [18]:
for r in range(50):
    for i, (model, loader, opt) in enumerate(zip(local_models, client_loaders, local_optimizers)):
        train_local(model, loader, opt, device, epochs=1)

    common_data = []
    avg_logits = []

    for _ in range(common_data_size):
        x = torch.randn(batch_size, 1, 28, 28, device=device)
        common_data.append(x)
        with torch.no_grad():
            local_logits = [m(x).detach() for m in local_models]
            avg_logits.append(torch.mean(torch.stack(local_logits), dim=0))
    common_data = torch.stack(common_data)
    train_server(server_model, avg_logits, common_data, server_optimizer, epochs=1)
    for model in local_models:
        model.load_state_dict(server_model.state_dict())

    acc = evaluate(server_model, test_loader, device)
    print(f"{r}: Server model accuracy on test set: {acc*100:.2f}%\n")

0: Server model accuracy on test set: 26.79%

1: Server model accuracy on test set: 49.95%

2: Server model accuracy on test set: 51.83%

3: Server model accuracy on test set: 56.74%

4: Server model accuracy on test set: 60.62%

5: Server model accuracy on test set: 64.79%

6: Server model accuracy on test set: 67.17%

7: Server model accuracy on test set: 69.19%

8: Server model accuracy on test set: 71.13%

9: Server model accuracy on test set: 72.38%

10: Server model accuracy on test set: 73.82%

11: Server model accuracy on test set: 74.64%

12: Server model accuracy on test set: 75.86%

13: Server model accuracy on test set: 76.77%

14: Server model accuracy on test set: 77.08%

15: Server model accuracy on test set: 77.60%

16: Server model accuracy on test set: 78.16%

17: Server model accuracy on test set: 79.53%

18: Server model accuracy on test set: 79.51%

19: Server model accuracy on test set: 80.03%

20: Server model accuracy on test set: 80.65%

21: Server model accura