In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Define a simple neural network for MNIST
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for 28x28 single-channel images (MNIST).

    Architecture: flatten -> Linear(784, hidden_dim) -> ReLU -> Linear(hidden_dim, num_classes).
    The forward pass returns raw logits (no softmax); it is meant to be paired
    with ``nn.CrossEntropyLoss``, which applies log-softmax itself.

    Args:
        hidden_dim: Width of the hidden layer (default 128).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 10):
        super().__init__()
        # Attribute names fc1/fc2 are kept so state_dict keys stay "fc1.weight", ...;
        # federated averaging matches parameters across models by these keys.
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images to class logits.

        Args:
            x: Tensor whose element count is a multiple of 784, e.g. a batch of
               shape (N, 1, 28, 28) or (N, 784).

        Returns:
            Logits of shape (N, num_classes).
        """
        # reshape (not view): view raises on non-contiguous inputs such as the
        # result of a transpose/permute, reshape copies only when it must.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Load MNIST dataset
def load_data(root="./data", download=True):
    """Load the MNIST train and test splits with a shared normalizing transform.

    ``ToTensor`` turns each image into a float tensor in [0, 1]; ``Normalize``
    with mean 0.5 and std 0.5 then maps it to [-1, 1].

    Args:
        root: Directory torchvision uses to store/look up MNIST (default "./data").
        download: Passed to ``datasets.MNIST``; when True the files are fetched
            into ``root`` if they are not already present.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.MNIST`` objects.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
    ])

    train_dataset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=download, transform=transform)
    return train_dataset, test_dataset

# Split data among clients
def split_data(train_dataset, num_clients, generator=None):
    """Randomly partition a dataset into ``num_clients`` disjoint subsets.

    Each client receives ``len(train_dataset) // num_clients`` samples and the
    remainder of that division is added to the last client, so every sample is
    used exactly once.

    Args:
        train_dataset: Any sized, indexable dataset.
        num_clients: Number of partitions; must be in [1, len(train_dataset)].
        generator: Optional ``torch.Generator`` for a reproducible split. When
            None, torch's global RNG is used (same as before).

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per client.

    Raises:
        ValueError: If ``num_clients`` is below 1 (previously a ZeroDivisionError)
            or exceeds the dataset size (previously left clients with no samples).
    """
    total = len(train_dataset)
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if num_clients > total:
        raise ValueError(
            f"num_clients ({num_clients}) exceeds dataset size ({total}); "
            "some clients would receive no samples"
        )

    share, remainder = divmod(total, num_clients)
    lengths = [share] * num_clients
    lengths[-1] += remainder  # last client absorbs the leftover samples

    if generator is None:
        return random_split(train_dataset, lengths)
    return random_split(train_dataset, lengths, generator=generator)

# Local training for each client
def local_train(model, data_loader, epochs, lr):
    """Train ``model`` in place with mini-batch SGD on one client's data.

    A new optimizer is created on every call, so no optimizer state is carried
    from one call (federated round) to the next.

    Args:
        model: The client's ``nn.Module``; its parameters are updated in place.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``data_loader``.
        lr: SGD learning rate.
    """
    sgd = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()  # put layers into training mode before the updates

    for _ in range(epochs):
        for batch_images, batch_labels in data_loader:
            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_images), batch_labels)
            batch_loss.backward()
            sgd.step()

# Federated averaging
def federated_average(global_model, client_models, weights=None):
    """Replace ``global_model``'s state with the average of the client states (FedAvg).

    Every ``state_dict()`` entry (parameters and buffers) is averaged element-wise
    across ``client_models``. With ``weights=None`` this is the plain mean, as
    before; pass per-client weights (e.g. local sample counts) for the
    size-weighted FedAvg update.

    Args:
        global_model: Model whose state is overwritten in place.
        client_models: Non-empty sequence of models sharing ``global_model``'s
            state_dict keys and tensor shapes.
        weights: Optional sequence of non-negative numbers, one per client.
            They are normalized internally, so only their ratios matter.

    Raises:
        ValueError: If ``client_models`` is empty, or ``weights`` has the wrong
            length, contains a negative value, or does not sum to a positive number.
    """
    if not client_models:
        raise ValueError("federated_average needs at least one client model")

    norm_weights = None
    if weights is not None:
        raw = [float(w) for w in weights]
        if len(raw) != len(client_models):
            raise ValueError(
                f"got {len(raw)} weights for {len(client_models)} client models"
            )
        if any(w < 0 for w in raw):
            raise ValueError("weights must be non-negative")
        total_weight = sum(raw)
        if not total_weight > 0:
            raise ValueError("weights must have a positive sum")
        norm_weights = [w / total_weight for w in raw]

    # Take each client's state_dict once rather than once per key.
    client_states = [client.state_dict() for client in client_models]

    averaged_state = {}
    for key in global_model.state_dict():
        stacked = torch.stack([state[key] for state in client_states], dim=0)
        # mean() is not defined for integer/bool tensors (e.g. BatchNorm's
        # num_batches_tracked): average those in float64, round, cast back.
        keeps_dtype = stacked.is_floating_point() or stacked.is_complex()
        work = stacked if keeps_dtype else stacked.to(torch.float64)
        if norm_weights is None:
            averaged = work.mean(dim=0)
        else:
            w = torch.tensor(norm_weights, dtype=torch.float64, device=work.device)
            # One weight per client, broadcast over the tensor's own dimensions.
            w = w.reshape(-1, *([1] * (work.dim() - 1)))
            averaged = (work * w).sum(dim=0)
        if not keeps_dtype:
            averaged = averaged.round()
        averaged_state[key] = averaged.to(stacked.dtype)

    global_model.load_state_dict(averaged_state)

# Test the model (client or global)
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

# Main function
def main(num_clients=3, local_epochs=2, global_epochs=5, lr=0.01, seed=0):
    """Run federated training of SimpleNN on MNIST and print accuracies per round.

    In each global round every client starts from the current global weights,
    trains locally, and reports accuracy on its own training shard. The client
    weights are then averaged into the global model, which is evaluated on the
    MNIST test split.

    Args:
        num_clients: Number of simulated clients sharing the training set.
        local_epochs: Local passes over each client's shard per round.
        global_epochs: Number of federated rounds.
        lr: SGD learning rate used by every client.
        seed: Seed for torch's global RNG, which drives the data split, weight
            initialization and shuffling, so reruns are reproducible.
            Pass None to leave the RNG unseeded.

    Returns:
        List of global-model test accuracies (percent), one per round.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Load and split data
    train_dataset, test_dataset = load_data()
    client_datasets = split_data(train_dataset, num_clients)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    # The client loaders are the same every round, so build them once.
    client_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in client_datasets]

    # Initialize models
    global_model = SimpleNN()
    client_models = [SimpleNN() for _ in range(num_clients)]

    global_accuracies = []
    for epoch in range(global_epochs):
        print(f"Global Epoch {epoch+1}/{global_epochs}")

        # Local training
        for client_id, (client_loader, client_model) in enumerate(zip(client_loaders, client_models)):
            print(f"  Client {client_id+1}")
            client_model.load_state_dict(global_model.state_dict())  # Sync with global model
            local_train(client_model, client_loader, local_epochs, lr)

            # NOTE: measured on the client's own *training* shard, not held-out
            # data, so this overstates how well the client model generalizes.
            client_accuracy = test_model(client_model, client_loader)
            print(f"    Client {client_id+1} Accuracy: {client_accuracy:.2f}%")

        # Federated averaging
        federated_average(global_model, client_models)

        # Test global model
        global_accuracy = test_model(global_model, test_loader)
        global_accuracies.append(global_accuracy)
        print(f"  Global Model Accuracy: {global_accuracy:.2f}%")

    return global_accuracies

# In a Jupyter/IPython kernel the notebook namespace's __name__ is "__main__",
# so this guard is true and main() runs whenever this cell is executed.
if __name__ == "__main__":
    main()


Global Epoch 1/5
  Client 1
    Client 1 Accuracy: 87.64%
  Client 2
    Client 2 Accuracy: 87.68%
  Client 3
    Client 3 Accuracy: 86.75%
  Global Model Accuracy: 88.00%
Global Epoch 2/5
  Client 1
    Client 1 Accuracy: 89.61%
  Client 2
    Client 2 Accuracy: 89.39%
  Client 3
    Client 3 Accuracy: 89.14%
  Global Model Accuracy: 89.88%
Global Epoch 3/5
  Client 1
    Client 1 Accuracy: 90.28%
  Client 2
    Client 2 Accuracy: 90.19%
  Client 3
    Client 3 Accuracy: 90.17%
  Global Model Accuracy: 90.74%
Global Epoch 4/5
  Client 1
    Client 1 Accuracy: 91.02%
  Client 2
    Client 2 Accuracy: 91.03%
  Client 3
    Client 3 Accuracy: 91.05%
  Global Model Accuracy: 91.39%
Global Epoch 5/5
  Client 1
    Client 1 Accuracy: 91.65%
  Client 2
    Client 2 Accuracy: 91.52%
  Client 3
    Client 3 Accuracy: 91.53%
  Global Model Accuracy: 91.93%


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Define a simple neural network for MNIST
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for 28x28 single-channel images (MNIST).

    Architecture: flatten -> Linear(784, hidden_dim) -> ReLU -> Linear(hidden_dim, num_classes).
    The forward pass returns raw logits (no softmax); it is meant to be paired
    with ``nn.CrossEntropyLoss``, which applies log-softmax itself.

    Args:
        hidden_dim: Width of the hidden layer (default 128).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 10):
        super().__init__()
        # Attribute names fc1/fc2 are kept so state_dict keys stay "fc1.weight", ...;
        # federated averaging matches parameters across models by these keys.
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images to class logits.

        Args:
            x: Tensor whose element count is a multiple of 784, e.g. a batch of
               shape (N, 1, 28, 28) or (N, 784).

        Returns:
            Logits of shape (N, num_classes).
        """
        # reshape (not view): view raises on non-contiguous inputs such as the
        # result of a transpose/permute, reshape copies only when it must.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Load MNIST dataset
def load_data(root="./data", download=True):
    """Load the MNIST train and test splits with a shared normalizing transform.

    ``ToTensor`` turns each image into a float tensor in [0, 1]; ``Normalize``
    with mean 0.5 and std 0.5 then maps it to [-1, 1].

    Args:
        root: Directory torchvision uses to store/look up MNIST (default "./data").
        download: Passed to ``datasets.MNIST``; when True the files are fetched
            into ``root`` if they are not already present.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.MNIST`` objects.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
    ])

    train_dataset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=download, transform=transform)
    return train_dataset, test_dataset

# Split data among clients
def split_data(train_dataset, num_clients, generator=None):
    """Randomly partition a dataset into ``num_clients`` disjoint subsets.

    Each client receives ``len(train_dataset) // num_clients`` samples and the
    remainder of that division is added to the last client, so every sample is
    used exactly once.

    Args:
        train_dataset: Any sized, indexable dataset.
        num_clients: Number of partitions; must be in [1, len(train_dataset)].
        generator: Optional ``torch.Generator`` for a reproducible split. When
            None, torch's global RNG is used (same as before).

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per client.

    Raises:
        ValueError: If ``num_clients`` is below 1 (previously a ZeroDivisionError)
            or exceeds the dataset size (previously left clients with no samples).
    """
    total = len(train_dataset)
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if num_clients > total:
        raise ValueError(
            f"num_clients ({num_clients}) exceeds dataset size ({total}); "
            "some clients would receive no samples"
        )

    share, remainder = divmod(total, num_clients)
    lengths = [share] * num_clients
    lengths[-1] += remainder  # last client absorbs the leftover samples

    if generator is None:
        return random_split(train_dataset, lengths)
    return random_split(train_dataset, lengths, generator=generator)

# Local training for each client
def local_train(model, data_loader, epochs, lr):
    """Run ``epochs`` passes of mini-batch SGD over ``data_loader``, updating ``model`` in place.

    Args:
        model: The client's ``nn.Module``; its parameters are modified in place.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``data_loader``.
        lr: SGD learning rate. A new optimizer is built per call, so no
            optimizer state survives between calls.
    """
    model.train()
    objective = nn.CrossEntropyLoss()
    step_optimizer = optim.SGD(model.parameters(), lr=lr)

    def sgd_step(inputs, targets):
        # One update: clear old gradients, forward, backpropagate, apply.
        step_optimizer.zero_grad()
        objective(model(inputs), targets).backward()
        step_optimizer.step()

    for _ in range(epochs):
        for inputs, targets in data_loader:
            sgd_step(inputs, targets)

# Federated averaging
def federated_average(global_model, client_models, weights=None):
    """Replace ``global_model``'s state with the average of the client states (FedAvg).

    Every ``state_dict()`` entry (parameters and buffers) is averaged element-wise
    across ``client_models``. With ``weights=None`` this is the plain mean, as
    before; pass per-client weights (e.g. local sample counts) for the
    size-weighted FedAvg update.

    Args:
        global_model: Model whose state is overwritten in place.
        client_models: Non-empty sequence of models sharing ``global_model``'s
            state_dict keys and tensor shapes.
        weights: Optional sequence of non-negative numbers, one per client.
            They are normalized internally, so only their ratios matter.

    Raises:
        ValueError: If ``client_models`` is empty, or ``weights`` has the wrong
            length, contains a negative value, or does not sum to a positive number.
    """
    if not client_models:
        raise ValueError("federated_average needs at least one client model")

    norm_weights = None
    if weights is not None:
        raw = [float(w) for w in weights]
        if len(raw) != len(client_models):
            raise ValueError(
                f"got {len(raw)} weights for {len(client_models)} client models"
            )
        if any(w < 0 for w in raw):
            raise ValueError("weights must be non-negative")
        total_weight = sum(raw)
        if not total_weight > 0:
            raise ValueError("weights must have a positive sum")
        norm_weights = [w / total_weight for w in raw]

    # Take each client's state_dict once rather than once per key.
    client_states = [client.state_dict() for client in client_models]

    averaged_state = {}
    for key in global_model.state_dict():
        stacked = torch.stack([state[key] for state in client_states], dim=0)
        # mean() is not defined for integer/bool tensors (e.g. BatchNorm's
        # num_batches_tracked): average those in float64, round, cast back.
        keeps_dtype = stacked.is_floating_point() or stacked.is_complex()
        work = stacked if keeps_dtype else stacked.to(torch.float64)
        if norm_weights is None:
            averaged = work.mean(dim=0)
        else:
            w = torch.tensor(norm_weights, dtype=torch.float64, device=work.device)
            # One weight per client, broadcast over the tensor's own dimensions.
            w = w.reshape(-1, *([1] * (work.dim() - 1)))
            averaged = (work * w).sum(dim=0)
        if not keeps_dtype:
            averaged = averaged.round()
        averaged_state[key] = averaged.to(stacked.dtype)

    global_model.load_state_dict(averaged_state)

# Test the global model
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

# Main function
def main(num_clients=5, local_epochs=2, global_epochs=10, lr=0.01, seed=0):
    """Run federated training of SimpleNN on MNIST, evaluating the global model each round.

    In each global round every client starts from the current global weights
    and trains locally. The client weights are then averaged into the global
    model, which ``test_model`` evaluates (and reports) on the MNIST test split.

    Args:
        num_clients: Number of simulated clients sharing the training set.
        local_epochs: Local passes over each client's shard per round.
        global_epochs: Number of federated rounds.
        lr: SGD learning rate used by every client.
        seed: Seed for torch's global RNG, which drives the data split, weight
            initialization and shuffling, so reruns are reproducible.
            Pass None to leave the RNG unseeded.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Load and split data
    train_dataset, test_dataset = load_data()
    client_datasets = split_data(train_dataset, num_clients)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    # The client loaders are the same every round, so build them once.
    client_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in client_datasets]

    # Initialize models
    global_model = SimpleNN()
    client_models = [SimpleNN() for _ in range(num_clients)]

    for epoch in range(global_epochs):
        print(f"Global Epoch {epoch+1}/{global_epochs}")

        # Local training
        for client_id, (client_loader, client_model) in enumerate(zip(client_loaders, client_models)):
            print(f"  Client {client_id+1}")
            client_model.load_state_dict(global_model.state_dict())  # Sync with global model
            local_train(client_model, client_loader, local_epochs, lr)

        # Federated averaging
        federated_average(global_model, client_models)

        # Test global model
        test_model(global_model, test_loader)

# In a Jupyter/IPython kernel the notebook namespace's __name__ is "__main__",
# so this guard is true and main() runs whenever this cell is executed.
if __name__ == "__main__":
    main()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.0MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 490kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.41MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.57MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/10
  Client 1
  Client 2
  Client 3
  Client 4
  Client 5
Test Accuracy: 86.23%
Global Epoch 2/10
  Client 1
  Client 2
  Client 3
  Client 4
  Client 5
Test Accuracy: 88.95%
Global Epoch 3/10
  Client 1
  Client 2
  Client 3
  Client 4
  Client 5
Test Accuracy: 89.80%
Global Epoch 4/10
  Client 1
  Client 2
  Client 3
  Client 4
  Client 5
Test Accuracy: 90.49%
Global Epoch 5/10
  Client 1
  Client 2
  Client 3
  Client 4
  Client 5
Test Accuracy: 91.00%
Global Epoch 6/10
  Client 1
  Client 2


KeyboardInterrupt: 