In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import defaultdict

In [11]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [12]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

In [13]:
def create_probabilistic_client_loaders(num_clients=20, p=0.7, batch_size=32):
    """
    Create client loaders for MNIST:
    - First take out a common dataset of size `common_size`.
    - Each class is assigned to each client with probability p.
    - Class samples are split among selected clients (no overlap).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    # Build class indices from remaining dataset
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    client_indices = [[] for _ in range(num_clients)]

    for class_idx in range(10):
        # Determine which clients get this class
        client_mask = np.random.rand(num_clients) < p
        selected_clients = np.where(client_mask)[0]

        if len(selected_clients) > 0:
            # Split class indices among selected clients
            splits = np.array_split(class_indices[class_idx], len(selected_clients))
            for client_id, split in zip(selected_clients, splits):
                client_indices[client_id].extend(split)

    # Create DataLoaders for clients
    client_loaders = []
    for indices in client_indices:
        if not indices:
            indices = [0]  # fallback if client gets no data
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True
        )
        client_loaders.append(loader)

    return client_loaders


In [14]:
def train_local_data(model, loader, optimizer, device, epochs=1):
    model.train()
    criterion = nn.CrossEntropyLoss()
    if not optimizer:
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

In [15]:
def train_common_data(model, avg_logits, common_data, optimizer, epochs=1):
    model.train()
    criterion = nn.MSELoss()
    if not optimizer:
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
    for _ in range(epochs):
        for x, y in zip(common_data, avg_logits):
            optimizer.zero_grad()
            server_logits = model(x)
            loss = criterion(server_logits, y)
            loss.backward()
            optimizer.step()

In [16]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return correct / total

In [17]:
num_clients = 10
batch_size = 128
common_data_size = 512
client_loaders = create_probabilistic_client_loaders(num_clients, batch_size=batch_size, p=0.2)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

local_data_models = [SimpleCNN().to(device) for _ in range(num_clients)]
common_data_models = [SimpleCNN().to(device) for _ in range(num_clients)]

optimizers = [optim.Adam(m.parameters(), lr=0.0001) for m in local_data_models]

In [None]:
for r in range(50):
    for i, (model, loader, opt) in enumerate(zip(local_data_models, client_loaders, optimizers)):
        train_local_data(model, loader, None, device, epochs=1)

    common_data = []
    avg_logits = []

    for _ in range(common_data_size):
        x = torch.randn(batch_size, 1, 28, 28, device=device)
        common_data.append(x)
        with torch.no_grad():
            local_logits = [m(x).detach() for m in local_data_models]
            avg_logits.append(torch.mean(torch.stack(local_logits), dim=0))
    common_data = torch.stack(common_data)

    accs = []
    for m, opt in zip(common_data_models, optimizers):
        train_common_data(m, avg_logits, common_data, None, epochs=1)
        acc = evaluate(m, test_loader, device)
    
    for local_data_model, common_data_model in zip(local_data_models, common_data_models):
        local_data_model.load_state_dict(common_data_model.state_dict())

    print(f"{r}: Server model accuracy on test set: {np.mean(acc)*100:.2f}%\n")

0: Server model accuracy on test set: 11.61%

1: Server model accuracy on test set: 16.11%

2: Server model accuracy on test set: 17.89%

3: Server model accuracy on test set: 18.51%

4: Server model accuracy on test set: 18.67%

5: Server model accuracy on test set: 18.98%

6: Server model accuracy on test set: 19.32%

7: Server model accuracy on test set: 19.67%

8: Server model accuracy on test set: 19.88%

9: Server model accuracy on test set: 20.29%

10: Server model accuracy on test set: 20.53%

11: Server model accuracy on test set: 20.80%

12: Server model accuracy on test set: 21.10%

13: Server model accuracy on test set: 21.45%

14: Server model accuracy on test set: 21.67%

15: Server model accuracy on test set: 21.91%

16: Server model accuracy on test set: 22.03%

17: Server model accuracy on test set: 22.23%

18: Server model accuracy on test set: 22.42%

