In [1]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision matplotlib

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 MB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m31.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Setup and Device Configuration
import torch
import flwr
from datasets.utils.logging import disable_progress_bar

In [3]:
# Pick the compute device once; models and batches are moved to DEVICE later on.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Record the device and library versions in the cell output.
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# Turn off the progress bars of the `datasets` library.
disable_progress_bar()

Training on cpu
Flower 1.18.0 / PyTorch 2.6.0+cu124


In [4]:
# Number of simulated clients (set to 5, modified from 10). The same value is used
# as the number of partitions of the training split and as `num_supernodes`
# when the simulation is started.
NUM_CLIENTS = 5
# Batch size shared by the train, validation and test DataLoaders.
BATCH_SIZE = 32

In [5]:
# Define Neural Network Model
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class Net(nn.Module):
    """Small CNN classifier: two conv + max-pool stages, then three linear layers.

    The first convolution takes 3 input channels and the final layer emits
    10 values per sample. The flatten step hard-codes ``16 * 5 * 5`` features,
    which is what the two 5x5 convolutions and 2x2 poolings produce for
    32x32 inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        # Convolutional feature extractor. Attribute names are also the
        # state_dict key prefixes, so they must stay stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalized) class scores for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [7]:
# Data Loading Functions
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset

In [8]:
def load_datasets(partition_id: int):
    """Build the DataLoaders used by one federated client.

    The "train" split of CIFAR-10 is partitioned into ``NUM_CLIENTS`` parts. The
    part selected by ``partition_id`` is split 80/20 (fixed seed 42) into local
    training and validation data. The test loader is built from the dataset's
    "test" split.

    Args:
        partition_id: Index of the training partition to load.

    Returns:
        A ``(trainloader, valloader, testloader)`` tuple. All loaders use
        ``BATCH_SIZE``; only the training loader shuffles.
    """
    to_normalized_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def apply_transforms(batch):
        # Convert every image of the batch; other entries are passed through as-is.
        batch["img"] = list(map(to_normalized_tensor, batch["img"]))
        return batch

    fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})

    # Local data of this client: 80% train, 20% validation.
    local_splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )
    local_splits = local_splits.with_transform(apply_transforms)
    trainloader = DataLoader(local_splits["train"], batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(local_splits["test"], batch_size=BATCH_SIZE)

    # Test data comes from the dataset's own "test" split.
    test_split = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(test_split, batch_size=BATCH_SIZE)

    return trainloader, valloader, testloader

In [9]:
# Training and Testing Functions
def train(net, trainloader, epochs: int, verbose=False):
    """Train ``net`` in place on ``trainloader`` for ``epochs`` passes.

    Uses cross-entropy loss and a new Adam optimizer with default
    hyperparameters, created on every call. Each batch is a mapping with
    "img" and "label" entries; both are moved to ``DEVICE`` before the
    forward pass.

    Args:
        net: Model to optimize; it is switched to training mode.
        trainloader: Iterable yielding the batches described above.
        epochs: Number of passes over ``trainloader``.
        verbose: When true, print the mean per-sample loss and the accuracy
            after every epoch.

    Returns:
        None. The model parameters are updated in place.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())

    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0

        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)

            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Metrics. `.item()` yields a plain float so no autograd history is
            # accumulated across batches. The criterion returns the batch mean,
            # so it is weighted by the batch size to get a per-sample average.
            batch_size = labels.size(0)
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (torch.max(outputs.detach(), 1)[1] == labels).sum().item()

        if total == 0:
            # Empty loader: nothing was processed, so there is nothing to report.
            continue

        epoch_loss /= total
        epoch_acc = correct / total

        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

In [10]:
def test(net, testloader):
    """Evaluate ``net`` on every batch of ``testloader``.

    Each batch is a mapping with "img" and "label" entries; both are moved to
    ``DEVICE``. The model is switched to evaluation mode and gradients are
    disabled for the whole pass.

    Args:
        net: Model to evaluate.
        testloader: Iterable yielding the batches described above.

    Returns:
        A ``(loss, accuracy)`` tuple of floats: the mean cross-entropy loss per
        sample and the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    # Sum the per-sample losses so that dividing by the sample count below gives
    # a true per-sample mean (dividing a sum of batch means by the sample count
    # would under-report the loss by roughly the batch size).
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Normalize by the number of samples actually seen.
    loss /= total
    accuracy = correct / total
    return loss, accuracy

In [11]:
# Parameter Handling Utilities
from collections import OrderedDict
from typing import List
import numpy as np

In [12]:
def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, matching arrays to state_dict keys by position.

    Args:
        net: Model whose ``state_dict`` is overwritten.
        parameters: One NumPy array per state_dict entry, in the order of
            ``net.state_dict().keys()``.

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when keys are
            missing or shapes do not match.
    """
    keys = net.state_dict().keys()
    # torch.tensor keeps each array's dtype and shape, including 0-dim arrays
    # such as BatchNorm's `num_batches_tracked`. The legacy torch.Tensor
    # constructor would force the default float type instead.
    state_dict = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)

def get_parameters(net) -> List[np.ndarray]:
    """Export every ``state_dict`` entry of ``net`` as a NumPy array.

    The arrays are returned in ``state_dict`` order; tensors are moved to the
    CPU before conversion.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

In [13]:
# Flower Client Definition
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context

class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates one model on its local data.

    Attributes are set from the constructor arguments:
        net: Model that is trained and evaluated.
        trainloader: DataLoader with the local training data.
        valloader: DataLoader with the local validation data.
    """

    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model parameters as NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load ``parameters``, train for one local epoch, return the result.

        Returns:
            ``(updated_parameters, num_examples, metrics)`` where
            ``num_examples`` is the number of local training samples.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # Report the number of samples, not len(DataLoader), which would be the
        # number of batches.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate them on the local validation data.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of local validation samples and ``metrics`` holds the
            accuracy under the "accuracy" key.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [14]:
# Client Factory Function
def client_fn(context: Context) -> Client:
    """Build the Flower client for the node described by ``context``.

    A new ``Net`` is created on ``DEVICE`` and paired with the train and
    validation loaders of the partition named by the node's "partition-id"
    config entry. The test loader returned by ``load_datasets`` is not used.
    """
    model = Net().to(DEVICE)

    train_dl, val_dl, _unused_test_dl = load_datasets(
        partition_id=context.node_config["partition-id"]
    )

    flower_client = FlowerClient(model, train_dl, val_dl)
    return flower_client.to_client()

In [15]:
# Federated Learning Strategy
from flwr.server.strategy import FedAvg

# Create the federated averaging strategy. The client minimums are tied to
# NUM_CLIENTS so they stay in sync when the number of clients is changed.
strategy = FedAvg(
    fraction_fit=1.0,                    # Sample 100% of available clients for training
    fraction_evaluate=1.0,               # Sample 100% of available clients for evaluation
    min_fit_clients=NUM_CLIENTS,         # Never sample fewer than all clients for training
    min_evaluate_clients=NUM_CLIENTS,    # Never sample fewer than all clients for evaluation
    min_available_clients=NUM_CLIENTS,   # Wait until all clients are available
)

In [16]:
# Server Configuration
from flwr.server import ServerApp, ServerConfig, ServerAppComponents

# Define server function
def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components: the module-level ``strategy`` and a 5-round config.

    ``context`` is accepted to match the expected signature but is not read.
    """
    return ServerAppComponents(
        strategy=strategy,
        config=ServerConfig(num_rounds=5),
    )

In [17]:
# Wrap the two factory functions in Flower apps: `client_fn` builds a client for
# a node, `server_fn` supplies the strategy and the round configuration.
client = ClientApp(client_fn=client_fn)
server = ServerApp(server_fn=server_fn)

In [18]:
# Resource Configuration and Simulation
from flwr.simulation import run_simulation

# Resources requested per simulated client: one CPU each, plus a full GPU
# when training runs on CUDA (none otherwise).
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if DEVICE == "cuda" else 0.0,
    }
}

In [19]:
# Run the federated learning simulation with one simulated node per client,
# using the per-client resources from `backend_config`.
# verbose_logging=True also emits DEBUG-level log records (see the cell output).
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,  # one node per client
    backend_config=backend_config,
    verbose_logging=True,
)

[94mDEBUG 2025-05-03 13:48:21,968[0m:     Asyncio event loop already running.
DEBUG:flwr:Asyncio event loop already running.
[94mDEBUG 2025-05-03 13:48:21,970[0m:     Logger propagate set to False
[94mDEBUG 2025-05-03 13:48:21,970[0m:     Pre-registering run with id 16603539432482731006
[94mDEBUG 2025-05-03 13:48:21,975[0m:     Using InMemoryState
[94mDEBUG 2025-05-03 13:48:21,976[0m:     Using InMemoryState
[92mINFO 2025-05-03 13:48:21,978[0m:      Starting Flower ServerApp, config: num_rounds=5, no round_timeout
[92mINFO 2025-05-03 13:48:21,981[0m:      
[94mDEBUG 2025-05-03 13:48:21,983[0m:     Using InMemoryState
[94mDEBUG 2025-05-03 13:48:21,984[0m:     Registered 5 nodes
[94mDEBUG 2025-05-03 13:48:21,984[0m:     Supported backends: ['ray']
[94mDEBUG 2025-05-03 13:48:21,987[0m:     Initialising: RayBackend
[94mDEBUG 2025-05-03 13:48:21,990[0m:     Backend config: {'client_resources': {'num_cpus': 1, 'num_gpus': 0.0}, 'init_args': {}, 'actor': {'tensorflow':