In [1]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

In [4]:

"""
Hierarchical Federated Learning with Flower (HierFAVG) with pushback.
Extends the previous script so that each global round runs edge-level FedAvg within every
edge server's client group, averages the resulting edge models at the cloud, and then
broadcasts the new global model back to all clients for an additional round of local training.
Based on "Client-Edge-Cloud Hierarchical Federated Learning" by Liu et al.
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr as fl
from flwr.client import NumPyClient, Client
from flwr.server import ServerConfig, ServerApp, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Context

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hierarchy shape: NUM_EDGE_SERVERS edges, each serving CLIENTS_PER_EDGE clients.
NUM_EDGE_SERVERS = 2
CLIENTS_PER_EDGE = 5
TOTAL_CLIENTS = CLIENTS_PER_EDGE * NUM_EDGE_SERVERS

# Training schedule.
BATCH_SIZE = 32
EDGE_ROUNDS = 3        # FedAvg rounds each edge server runs per global round
GLOBAL_ROUNDS = 2      # number of cloud aggregation rounds
LOCAL_EPOCHS_PUSH = 1  # local epochs per client during the pushback round

# Edge and global checkpoints are written here; create it up front so
# later torch.save calls cannot fail on a missing directory.
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)

#----------------------------------------------------------------
# Model Definition
#----------------------------------------------------------------
class Net(nn.Module):
    """LeNet-style CNN producing 10 logits.

    The flatten size 16 * 5 * 5 implies 3-channel 32x32 inputs
    (32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5).
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names and their definition order fix the state_dict key
        # order, which get_parameters/set_parameters rely on positionally.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to unnormalized class scores."""
        # Two conv -> ReLU -> max-pool stages.
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Two hidden fully-connected layers with ReLU, then the output layer.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

#----------------------------------------------------------------
# Utility: get/set parameters
#----------------------------------------------------------------
def get_parameters(model: nn.Module) -> NDArrays:
    """Return every state_dict entry of *model* as a CPU NumPy array, in key order."""
    state = model.state_dict()
    return list(map(lambda tensor: tensor.cpu().numpy(), state.values()))

def set_parameters(model: nn.Module, params: NDArrays) -> None:
    """Load *params* into *model*, pairing arrays with state_dict keys by position.

    strict=True makes load_state_dict reject a list that leaves keys unfilled.
    """
    keys = list(model.state_dict().keys())
    new_state = {}
    for key, array in zip(keys, params):
        new_state[key] = torch.tensor(array)
    model.load_state_dict(new_state, strict=True)

#----------------------------------------------------------------
# Data Loading
# Per-process cache of FederatedDataset objects keyed by partition count.
# client_fn builds a client (and so calls load_datasets) for every client in
# every round; reusing one instance avoids re-creating the dataset and its
# partitioner each time.
_FDS_CACHE: dict = {}


def _get_federated_dataset(num_partitions: int):
    """Return a cached CIFAR-10 FederatedDataset whose train split has num_partitions partitions."""
    fds = _FDS_CACHE.get(num_partitions)
    if fds is None:
        fds = FederatedDataset(dataset="uoft-cs/cifar10", partitioners={"train": num_partitions})
        _FDS_CACHE[num_partitions] = fds
    return fds


def load_datasets(partition_id: int, num_partitions: int):
    """Build DataLoaders for one client's CIFAR-10 partition.

    The partition is split 80/20 (seed 42) into train/validation. The shared
    "test" split is also wrapped in a loader.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of partitions the train split is cut into.

    Returns:
        (trainloader, valloader, testloader). Batches are dicts whose "img"
        entries are normalized tensors; "label" is passed through unchanged.
    """
    fds = _get_federated_dataset(num_partitions)
    part = fds.load_partition(partition_id)
    part_train_test = part.train_test_split(test_size=0.2, seed=42)

    # Scale pixels to [0, 1] and then to [-1, 1] per channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def apply_tf(batch):
        batch["img"] = [transform(img) for img in batch["img"]]
        return batch

    # with_transform applies apply_tf lazily when items are accessed.
    part_train_test = part_train_test.with_transform(apply_tf)
    trainloader = DataLoader(part_train_test["train"], batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(part_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_tf)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

#----------------------------------------------------------------
# Flower Client Implementation
#----------------------------------------------------------------
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a local copy of Net.

    Args:
        cid: Global partition id of this client (stored; not read by the
            methods below).
        net: Model instance; moved to ``device`` on construction.
        trainloader: Loader yielding dicts with "img" and "label" tensors.
        valloader: Loader with the same batch format, used by ``evaluate``.
    """

    def __init__(self, cid: int, net: Net, trainloader, valloader):
        self.cid = cid
        self.net = net.to(device)
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load *parameters* and train with Adam for config["local_epochs"] epochs (default 1).

        Returns:
            (updated weights, number of training examples, empty metrics dict).
        """
        set_parameters(self.net, parameters)
        self.net.train()
        criterion = nn.CrossEntropyLoss()
        # A fresh optimizer per fit call: no Adam state carries across rounds.
        optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)
        # int() so a numeric Scalar such as 1.0 in the config still works with range().
        for _ in range(int(config.get("local_epochs", 1))):
            for batch in self.trainloader:
                imgs, labels = batch["img"].to(device), batch["label"].to(device)
                optimizer.zero_grad()
                outputs = self.net(imgs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load *parameters* and evaluate on the validation loader.

        Returns:
            (mean per-batch loss, number of validation examples, {"accuracy": ...}).
            An empty validation set yields loss 0.0 and accuracy 0.0. With
            0 examples it carries no weight in Flower's weighted average,
            instead of raising ZeroDivisionError.
        """
        set_parameters(self.net, parameters)
        self.net.eval()
        criterion = nn.CrossEntropyLoss()
        num_examples = len(self.valloader.dataset)
        num_batches = len(self.valloader)
        loss, correct = 0.0, 0
        with torch.no_grad():
            for batch in self.valloader:
                imgs, labels = batch["img"].to(device), batch["label"].to(device)
                outputs = self.net(imgs)
                loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
        if num_batches == 0 or num_examples == 0:
            return 0.0, num_examples, {"accuracy": 0.0}
        return loss / num_batches, num_examples, {"accuracy": correct / num_examples}

#----------------------------------------------------------------
# Simulate one edge server given initial global parameters
#----------------------------------------------------------------
def simulate_edge(edge_id: int, client_ids: list, initial_params: NDArrays) -> NDArrays:
    """Run EDGE_ROUNDS of FedAvg among one edge server's clients.

    Args:
        edge_id: Index of the edge server; used in the checkpoint filename.
        client_ids: Global partition ids of this edge's clients. Simulated node
            k is mapped to partition client_ids[k].
        initial_params: Parameters every client starts from in round 1.

    Returns:
        The edge model's parameters after the final edge round. The same
        weights are also saved to MODEL_DIR/edge_{edge_id}_round_{EDGE_ROUNDS}.pth.

    Raises:
        FileNotFoundError: if the simulation finished without reaching the
            final edge round, so no edge model was produced.
    """
    checkpoint_path = os.path.join(MODEL_DIR, f"edge_{edge_id}_round_{EDGE_ROUNDS}.pth")
    # The filename does not depend on the global round. Remove a checkpoint left
    # by an earlier call so a failed run can never return stale weights.
    try:
        os.remove(checkpoint_path)
    except FileNotFoundError:
        pass
    # Filled by evaluate_fn, which FedAvg calls on the server side.
    final_params: dict = {}

    # client_fn maps the local node index to the global partition.
    def client_fn(ctx: Context) -> Client:
        local_pid = ctx.node_config["partition-id"]
        global_pid = client_ids[local_pid]
        net = Net()
        trainloader, valloader, _ = load_datasets(global_pid, TOTAL_CLIENTS)
        return FlowerClient(global_pid, net, trainloader, valloader).to_client()

    # Capture and save the edge model after the final round. Returning None
    # means no centralized metrics are reported.
    def evaluate_fn(server_round: int, parameters: NDArrays, config: dict):
        if server_round == EDGE_ROUNDS:
            final_params["params"] = [np.copy(p) for p in parameters]
            edge_model = Net().to(device)
            set_parameters(edge_model, parameters)
            torch.save(edge_model.state_dict(), checkpoint_path)
        return None

    def server_fn(ctx: Context) -> ServerAppComponents:
        strategy = FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=0.0,  # no client-side evaluation
            min_fit_clients=len(client_ids),
            min_available_clients=len(client_ids),
            initial_parameters=ndarrays_to_parameters(initial_params),
            on_fit_config_fn=lambda rnd: {"local_epochs": 1},
            evaluate_fn=evaluate_fn,
        )
        return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=EDGE_ROUNDS))

    run_simulation(
        client_app=fl.client.ClientApp(client_fn=client_fn),
        server_app=ServerApp(server_fn=server_fn),
        num_supernodes=len(client_ids),
    )

    if "params" in final_params:
        return final_params["params"]
    # Fallback: reload from disk. Any stale file was removed above, so this
    # either reads what evaluate_fn wrote in this call or raises FileNotFoundError.
    edge_model = Net().to(device)
    edge_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    return get_parameters(edge_model)

#----------------------------------------------------------------
# Broadcast global model back to all clients for additional local training
#----------------------------------------------------------------
def pushback_to_clients(global_params: NDArrays):
    """Broadcast *global_params* to all TOTAL_CLIENTS clients for extra local training.

    Runs one FedAvg round in which every client starts from *global_params*
    and trains for LOCAL_EPOCHS_PUSH epochs.

    Args:
        global_params: Cloud-aggregated parameters to broadcast.

    Returns:
        The FedAvg aggregate of the clients' updated weights, or None if the
        server did not report it. client_fn builds a fresh Net on every call,
        so clients keep no state between simulations. This return value is
        the only place the pushback training is preserved.
    """
    # Filled by evaluate_fn, which FedAvg calls on the server side after aggregation.
    result: dict = {}

    # Node k trains on partition k; there is no edge grouping here.
    def client_fn(ctx: Context) -> Client:
        pid = ctx.node_config["partition-id"]
        net = Net()
        trainloader, valloader, _ = load_datasets(pid, TOTAL_CLIENTS)
        return FlowerClient(pid, net, trainloader, valloader).to_client()

    def evaluate_fn(server_round: int, parameters: NDArrays, config: dict):
        # Round 0 evaluates the initial (broadcast) parameters. Only round 1
        # carries the post-training aggregate.
        if server_round == 1:
            result["params"] = [np.copy(p) for p in parameters]
        return None

    def server_fn(ctx: Context) -> ServerAppComponents:
        strategy = FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=0.0,  # no client-side evaluation
            min_fit_clients=TOTAL_CLIENTS,
            min_available_clients=TOTAL_CLIENTS,
            initial_parameters=ndarrays_to_parameters(global_params),
            on_fit_config_fn=lambda rnd: {"local_epochs": LOCAL_EPOCHS_PUSH},
            evaluate_fn=evaluate_fn,
        )
        return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=1))

    run_simulation(
        client_app=fl.client.ClientApp(client_fn=client_fn),
        server_app=ServerApp(server_fn=server_fn),
        num_supernodes=TOTAL_CLIENTS,
    )
    return result.get("params")

#----------------------------------------------------------------
# Main: Outer global rounds with HFL and pushback
#----------------------------------------------------------------
def _save_model(params: NDArrays, filename: str) -> None:
    """Save *params* as a Net state_dict at MODEL_DIR/filename."""
    model = Net().to(device)
    set_parameters(model, params)
    torch.save(model.state_dict(), os.path.join(MODEL_DIR, filename))


def main():
    """Run hierarchical FL with pushback for GLOBAL_ROUNDS global rounds.

    Each global round has three steps:
      1. Every edge runs EDGE_ROUNDS of FedAvg over its clients, starting from
         the current global parameters.
      2. The cloud averages the edge models layer by layer, weighting each edge
         by its number of clients, and saves the result as global_round_{gr}.pth.
      3. The global model is pushed back to all clients for extra local
         training. If the pushback returns its aggregate, that aggregate is
         saved as global_round_{gr}_pushback.pth and seeds the next round.
    """
    global_params = get_parameters(Net())

    # Split client IDs into contiguous groups, one per edge.
    all_ids = list(range(TOTAL_CLIENTS))
    edges = [all_ids[i * CLIENTS_PER_EDGE:(i + 1) * CLIENTS_PER_EDGE] for i in range(NUM_EDGE_SERVERS)]
    # Cloud weights: clients per edge, used as a proxy for each edge's data share.
    # NOTE(review): assumes partitions are roughly equal-sized. With equal-sized
    # edges, as here, this reduces to a plain mean.
    edge_weights = [len(cids) for cids in edges]

    for gr in range(1, GLOBAL_ROUNDS + 1):
        print(f"\n=== Global Round {gr} ===")
        # 1) Edge-level training & aggregation
        edge_params_list = []
        for eid, cids in enumerate(edges):
            print(f"Simulating Edge {eid} with clients {cids}")
            edge_params_list.append(simulate_edge(eid, cids, global_params))

        # 2) Cloud aggregation: weighted layer-wise average. Cast back to each
        # layer's dtype because np.average promotes float32 to float64 here.
        global_params = [
            np.average(np.stack(layers), axis=0, weights=edge_weights).astype(layers[0].dtype, copy=False)
            for layers in zip(*edge_params_list)
        ]
        _save_model(global_params, f"global_round_{gr}.pth")
        print(f"Saved global model for round {gr}")

        # 3) Pushback: broadcast the new global model to all clients for extra local training.
        print(f"Pushback: clients training again on global model round {gr}")
        pushed_params = pushback_to_clients(global_params)
        if pushed_params is not None:
            # Clients are rebuilt from scratch on every simulation. Carrying the
            # pushback aggregate forward is what keeps that training.
            global_params = pushed_params
            _save_model(global_params, f"global_round_{gr}_pushback.pth")

    print("\nHierarchical FL with pushback complete.")


if __name__ == "__main__":
    main()


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)



=== Global Round 1 ===
Simulating Edge 0 with clients [0, 1, 2, 3, 4]


[36m(pid=26543)[0m 2025-05-19 20:10:21.434222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=26543)[0m E0000 00:00:1747685421.463945   26543 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=26543)[0m E0000 00:00:1747685421.473825   26543 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=26543)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=26543)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=26543)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=26543)[0m   from jupyter_core.paths import jupyter_

Simulating Edge 1 with clients [5, 6, 7, 8, 9]


[36m(pid=27347)[0m 2025-05-19 20:12:07.251628: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=27347)[0m E0000 00:00:1747685527.280626   27347 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=27347)[0m E0000 00:00:1747685527.288351   27347 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=27347)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=27347)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=27347)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=27347)[0m   from jupyter_core.paths import jupyter_

Saved global model for round 1
Pushback: clients training again on global model round 1


[36m(pid=28130)[0m 2025-05-19 20:13:45.440417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=28130)[0m E0000 00:00:1747685625.464984   28130 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=28130)[0m E0000 00:00:1747685625.472170   28130 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=28130)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=28130)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=28130)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=28130)[0m   from jupyter_core.paths import jupyter_


=== Global Round 2 ===
Simulating Edge 0 with clients [0, 1, 2, 3, 4]


[36m(pid=28759)[0m 2025-05-19 20:15:00.697503: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=28759)[0m E0000 00:00:1747685700.739551   28759 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=28759)[0m E0000 00:00:1747685700.746790   28759 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=28759)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=28759)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=28759)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=28759)[0m   from jupyter_core.paths import jupyter_

Simulating Edge 1 with clients [5, 6, 7, 8, 9]


[36m(pid=29578)[0m 2025-05-19 20:16:49.136363: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=29578)[0m E0000 00:00:1747685809.177566   29578 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=29578)[0m E0000 00:00:1747685809.189184   29578 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=29578)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=29578)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=29578)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=29578)[0m   from jupyter_core.paths import jupyter_

Saved global model for round 2
Pushback: clients training again on global model round 2


[36m(pid=30390)[0m 2025-05-19 20:18:35.731075: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=30390)[0m E0000 00:00:1747685915.776644   30390 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=30390)[0m E0000 00:00:1747685915.789194   30390 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=30390)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=30390)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=30390)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=30390)[0m   from jupyter_core.paths import jupyter_


Hierarchical FL with pushback complete.
