In [None]:
pip install syft==0.7.0 torch torchvision


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# PySyft is optional in this notebook. The TorchHook / VirtualWorker API
# belongs to PySyft 0.2.x (where both are exported from the top-level `syft`
# package); it is not importable from syft==0.7.0, which made the original
# imports crash the whole notebook. The virtual workers created here are not
# used by any later cell — the federated simulation below runs with plain
# PyTorch — so fall back to None when the API is unavailable.
try:
    from syft import TorchHook, VirtualWorker
except ImportError:
    TorchHook = VirtualWorker = None

if TorchHook is not None:
    # Hook PyTorch to enable PySyft's federated-learning tensor API
    hook = TorchHook(torch)

    # Simulate two clients
    client1 = VirtualWorker(hook, id="client1")
    client2 = VirtualWorker(hook, id="client2")
else:
    hook = client1 = client2 = None


In [None]:
# Define a simple neural network
class Net(nn.Module):
    """Two-layer fully connected classifier, MNIST-sized by default.

    Args:
        in_features: number of values per sample after flattening
            (default 28 * 28).
        hidden_features: width of the hidden ReLU layer (default 128).
        num_classes: number of output logits (default 10).

    The defaults reproduce the original architecture exactly, and the layer
    names ``fc1``/``fc2`` are unchanged so existing state_dicts still load.
    """

    def __init__(self, in_features=28 * 28, hidden_features=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        # Flatten each sample to fc1.in_features values; with the defaults an
        # N x 1 x 28 x 28 batch becomes N x 784.
        x = x.view(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Seed PyTorch's global RNG before building the model so weight initialisation
# (and any later sampling that draws from the global RNG) is reproducible
# across Restart & Run All.
torch.manual_seed(0)

# Initialize the global model
model = Net()


In [None]:
# Transform and load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Split the training set into two disjoint client shards. Sizes are derived
# from the dataset length instead of hard-coding 30000/30000, and the split
# uses its own seeded generator so every run assigns the same samples to
# each client.
SPLIT_SEED = 42
client1_size = len(mnist_train) // 2
client2_size = len(mnist_train) - client1_size
mnist_train_client1, mnist_train_client2 = torch.utils.data.random_split(
    mnist_train,
    [client1_size, client2_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# One local DataLoader per client. Nothing is sent to a remote worker: both
# shards stay in this process and each loader stands in for a client's data.
train_loader_client1 = torch.utils.data.DataLoader(mnist_train_client1, batch_size=64, shuffle=True)
train_loader_client2 = torch.utils.data.DataLoader(mnist_train_client2, batch_size=64, shuffle=True)


In [None]:
# Training loss: cross-entropy over the network's class logits
criterion = nn.CrossEntropyLoss(reduction="mean")

def train_on_client(model, train_loader, optimizer, criterion=None):
    """Run one local training epoch of ``model`` over ``train_loader``.

    Args:
        model: network to train in place; switched to train mode.
        train_loader: iterable of ``(data, target)`` batches, e.g. a DataLoader.
        optimizer: optimizer bound to ``model``'s parameters; stepped once
            per batch.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``,
            the same loss the notebook defines, so the function no longer
            depends on a hidden global.

    Returns:
        Mean training loss per sample over the epoch (each batch weighted by
        its size, so a short final batch is not over-counted), or ``nan`` if
        the loader yields no samples.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        # Assumes a mean-reduced criterion (the CrossEntropyLoss default):
        # re-weight the batch mean by batch size for an exact per-sample mean.
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return total_loss / num_samples if num_samples else float("nan")

# Federated training (FedAvg). Each round, every client starts from the current
# global weights, trains its own replica on its local shard, and the global
# model is replaced by the sample-weighted mean of the replicas. (Previously
# both optimizers stepped the same global `model`, so client 2 simply kept
# training client 1's result — centralized training, not federated.)
# The averaging is done inline because `federated_aggregation` is defined in a
# later cell and is not yet available when this cell runs top-to-bottom.
client_loaders = [train_loader_client1, train_loader_client2]
client_sizes = [len(loader.dataset) for loader in client_loaders]
total_size = sum(client_sizes)

for epoch in range(1, 3):  # Two federated rounds for demonstration
    global_state = model.state_dict()
    client_states = []
    client_losses = []
    for loader in client_loaders:
        replica = Net()
        replica.load_state_dict(global_state)
        replica_optimizer = optim.SGD(replica.parameters(), lr=0.01)
        client_losses.append(train_on_client(replica, loader, replica_optimizer))
        client_states.append(replica.state_dict())

    # Weight each client's parameters by its share of the training samples
    model.load_state_dict({
        key: sum(state[key] * (size / total_size) for state, size in zip(client_states, client_sizes))
        for key in global_state
    })

    loss_client1, loss_client2 = client_losses
    print(f"Epoch {epoch}, Loss on Client 1: {loss_client1:.4f}, Loss on Client 2: {loss_client2:.4f}")


In [None]:
def federated_aggregation(global_model, client_models, weights=None):
    """Replace ``global_model``'s state with the (weighted) average of the clients'.

    Every entry of the global ``state_dict`` becomes the weighted mean of the
    same entry across ``client_models`` (the FedAvg aggregation step).

    Args:
        global_model: module whose state is overwritten in place.
        client_models: non-empty sequence of modules whose state_dicts have the
            same keys and shapes as ``global_model``'s.
        weights: optional non-negative per-client weights (e.g. local sample
            counts). ``None`` means a plain unweighted mean, as before.

    Raises:
        ValueError: if ``client_models`` is empty, ``weights`` has the wrong
            length, contains a negative value, or does not sum to a positive
            value.
    """
    if not client_models:
        raise ValueError("federated_aggregation needs at least one client model")
    if weights is None:
        weights = [1.0] * len(client_models)
    if len(weights) != len(client_models):
        raise ValueError(
            f"got {len(weights)} weights for {len(client_models)} client models"
        )
    coeffs = torch.tensor([float(w) for w in weights], dtype=torch.float64)
    if (coeffs < 0).any() or coeffs.sum() <= 0:
        raise ValueError("client weights must be non-negative and sum to a positive value")
    coeffs = coeffs / coeffs.sum()

    # Build each client's state_dict once, not once per parameter key
    client_states = [client.state_dict() for client in client_models]

    averaged_state = {}
    for key in global_model.state_dict():
        stacked = torch.stack([state[key] for state in client_states])
        # Coefficient shape that broadcasts over the leading (client) axis
        shape = (-1,) + (1,) * (stacked.dim() - 1)
        if stacked.is_floating_point() or stacked.is_complex():
            scale = coeffs.to(device=stacked.device, dtype=stacked.dtype).view(shape)
            averaged_state[key] = (stacked * scale).sum(dim=0)
        else:
            # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked) cannot
            # be averaged in their own dtype -- `.mean` raises on them -- so
            # average in float64 and round back to the original dtype.
            scale = coeffs.to(device=stacked.device).view(shape)
            averaged_state[key] = (
                (stacked.to(torch.float64) * scale).sum(dim=0).round().to(stacked.dtype)
            )
    global_model.load_state_dict(averaged_state)

# Build two client replicas that start from the current global weights.
# NOTE(review): the replicas are not trained in this cell, so they hold
# identical copies of `model`'s weights and averaging them leaves `model`
# unchanged -- this cell only demonstrates the aggregation call.
client_model1, client_model2 = Net(), Net()
for replica in (client_model1, client_model2):
    replica.load_state_dict(model.state_dict())

# Fold the replicas back into the global model
federated_aggregation(model, [client_model1, client_model2])


In [None]:
# Held-out MNIST test split, preprocessed with the same `transform` as training
mnist_test = datasets.MNIST('./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=False)

# Evaluate the global model
def evaluate(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: network producing one logit per class along dim 1; switched to
            eval mode (and left in it).
        test_loader: iterable of ``(data, target)`` batches, e.g. a DataLoader.

    Returns:
        Fraction of samples whose arg-max logit equals the target. The
        denominator is the number of samples actually iterated, so the result
        stays correct with ``drop_last=True`` or a plain list of batches.
        Returns ``nan`` if no samples were seen.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += target.numel()
    return correct / seen if seen else float("nan")

# Score the aggregated global model on the held-out test split
accuracy = evaluate(model=model, test_loader=test_loader)
print(f"Global Model Test Accuracy: {accuracy:.4f}")
