In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import defaultdict

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

In [22]:
def create_probabilistic_client_loaders(num_clients=20, p=0.7, batch_size=32):
    """
    Create client loaders for MNIST:
    - First take out a common dataset of size `common_size`.
    - Each class is assigned to each client with probability p.
    - Class samples are split among selected clients (no overlap).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    # Build class indices from remaining dataset
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    client_indices = [[] for _ in range(num_clients)]

    for class_idx in range(10):
        # Determine which clients get this class
        client_mask = np.random.rand(num_clients) < p
        selected_clients = np.where(client_mask)[0]

        if len(selected_clients) > 0:
            # Split class indices among selected clients
            splits = np.array_split(class_indices[class_idx], len(selected_clients))
            for client_id, split in zip(selected_clients, splits):
                client_indices[client_id].extend(split)

    # Create DataLoaders for clients
    client_loaders = []
    for indices in client_indices:
        if not indices:
            indices = [0]  # fallback if client gets no data
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True
        )
        client_loaders.append(loader)

    return client_loaders


In [23]:
def train_local(model, loader, optimizer, device, epochs=1):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

In [24]:
def train_server(server_model, avg_logits, common_data, optimizer, epochs=1):
    server_model.train()
    criterion = nn.MSELoss()

    for _ in range(epochs):
        for x, y in zip(common_data, avg_logits):
            optimizer.zero_grad()
            server_logits = server_model(x)
            loss = criterion(server_logits, y)
            loss.backward()
            optimizer.step()

In [25]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return correct / total

In [26]:
num_clients = 30
batch_size = 128
common_data_size = 512
client_loaders = create_probabilistic_client_loaders(num_clients, batch_size=batch_size, p=0.2)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

local_models = [SimpleCNN().to(device) for _ in range(num_clients)]
server_model = SimpleCNN().to(device)

local_optimizers = [optim.Adam(m.parameters(), lr=0.001) for m in local_models]
server_optimizer = optim.SGD(server_model.parameters(), lr=0.01)


In [27]:
for r in range(50):
    for i, (model, loader, opt) in enumerate(zip(local_models, client_loaders, local_optimizers)):
        train_local(model, loader, opt, device, epochs=1)

    common_data = []
    avg_logits = []

    for _ in range(common_data_size):
        x = torch.randn(batch_size, 1, 28, 28, device=device)
        common_data.append(x)
        with torch.no_grad():
            local_logits = [m(x).detach() for m in local_models]
            avg_logits.append(torch.mean(torch.stack(local_logits), dim=0))
    common_data = torch.stack(common_data)

    train_server(server_model, avg_logits, common_data, server_optimizer, epochs=1)
    for model in local_models:
        model.load_state_dict(server_model.state_dict())

    acc = evaluate(server_model, test_loader, device)
    print(f"Server model accuracy on test set: {acc*100:.2f}%\n")

Server model accuracy on test set: 10.83%

Server model accuracy on test set: 19.65%

Server model accuracy on test set: 28.62%

Server model accuracy on test set: 39.30%

Server model accuracy on test set: 45.78%

Server model accuracy on test set: 49.88%

Server model accuracy on test set: 53.50%

Server model accuracy on test set: 56.58%

Server model accuracy on test set: 59.67%

Server model accuracy on test set: 62.31%

Server model accuracy on test set: 64.54%

Server model accuracy on test set: 66.81%

Server model accuracy on test set: 68.78%

Server model accuracy on test set: 70.19%

Server model accuracy on test set: 71.59%

Server model accuracy on test set: 72.64%

Server model accuracy on test set: 74.05%

Server model accuracy on test set: 75.03%

Server model accuracy on test set: 76.42%

Server model accuracy on test set: 77.35%

Server model accuracy on test set: 78.16%

Server model accuracy on test set: 78.95%

Server model accuracy on test set: 79.92%

Server mode