In [4]:
%cd .

d:\MachineLearning\federated_vae\main


In [5]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

Training on cuda
Flower 1.19.0 / PyTorch 2.5.1+cu121


In [6]:
NUM_CLIENTS = 2
BATCH_SIZE = 256
NUM_EPOCHS = 100

from test_flwr import get_all_vocab, split_data
vocab = get_all_vocab(["../data/20NG"])
datasets = split_data(dir = "../data/20NG", num_split=20, vocab = vocab, batch_size= BATCH_SIZE)[:NUM_CLIENTS]

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


loading train texts: 100%|██████████| 565/565 [00:00<00:00, 9738.00it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 11789.37it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 8739.26it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 10508.26it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 8431.08it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 10921.46it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 9202.96it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 11279.57it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 9010.44it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 11033.99it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 8500.85it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 10163.67it/s]
loading train texts: 100%|██████████| 565/565 [00:00<00:00, 7151.77it/s]
parsing texts: 100%|██████████| 565/565 [00:00<00:00, 9099.57it/s]
loading train 

In [7]:
from model.ETM import ETM
from trainer.basic_trainer import BasicTrainer

# net = ETM(len(vocab)).to(DEVICE)

# trainer = BasicTrainer(model = net, dataset = datasets[0], epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
#                        log_interval=10)

# trainer.train()


In [8]:
# res = trainer.get_top_words()
# print(res)

In [9]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [10]:
from data.basic_dataset import RawDataset
class FlowerClient(NumPyClient):
  def __init__(self, net, dataset : RawDataset):
    self.net = net
    self.dataset = dataset
    self.trainer = BasicTrainer(net, dataset, epochs = 1, log_interval=10, device = DEVICE)

  # return the current local model parameters
  def get_parameters(self, config):
    return get_parameters(self.net)

  # receive global parameter, train, return updated model to server
  def fit(self, parameters, config):
    set_parameters(self.net, parameters)
    self.trainer.train()

    return get_parameters(self.net), len(self.dataset.train_texts), {}

  # receive global parameter, evaluate model from local's data, return the evaluation result
  def evaluate(self, parameters, config):
    set_parameters(self.net, parameters)
    loss, acc = -1, -1
    return float(loss), 0, {"accuracy":float(acc)}
  


test = FlowerClient(ETM(len(vocab)), datasets[0])

In [11]:
def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = ETM(len(vocab)).to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    dataset = datasets[partition_id]

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return FlowerClient(net, dataset).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [12]:
print(test.net.parameters)

<bound method Module.parameters of ETM(
  (encoder1): Sequential(
    (0): Linear(in_features=5000, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.0, inplace=False)
  )
  (fc21): Linear(in_features=800, out_features=50, bias=True)
  (fc22): Linear(in_features=800, out_features=50, bias=True)
)>


In [13]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_available_clients=NUM_CLIENTS,
    )
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [14]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client
if DEVICE == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

In [15]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    verbose_logging=True
)

[94mDEBUG 2025-07-12 22:48:56,926[0m:     Asyncio event loop already running.
[94mDEBUG 2025-07-12 22:48:56,928[0m:     Logger propagate set to False
[94mDEBUG 2025-07-12 22:48:56,929[0m:     Pre-registering run with id 2335252092985170524
[94mDEBUG 2025-07-12 22:48:56,933[0m:     Using InMemoryState
[94mDEBUG 2025-07-12 22:48:56,934[0m:     Using InMemoryState
[92mINFO 2025-07-12 22:48:56,937[0m:      Starting Flower ServerApp, config: num_rounds=5, no round_timeout
[94mDEBUG 2025-07-12 22:48:56,939[0m:     Using InMemoryState
[92mINFO 2025-07-12 22:48:56,940[0m:      
[94mDEBUG 2025-07-12 22:48:56,941[0m:     Registered 2 nodes
[94mDEBUG 2025-07-12 22:48:56,942[0m:     Supported backends: ['ray']
[94mDEBUG 2025-07-12 22:48:56,943[0m:     Initialising: RayBackend
[92mINFO 2025-07-12 22:48:56,943[0m:      [INIT]
[94mDEBUG 2025-07-12 22:48:56,943[0m:     Backend config: {'client_resources': {'num_cpus': 1, 'num_gpus': 1.0}, 'init_args': {}, 'actor': {'tensorflo

[36m(ClientAppActor pid=22140)[0m Epoch: 000 | Loss: 1960.4237060546875


[92mINFO 2025-07-12 22:49:20,698[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 22:49:20,778[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=22140)[0m Epoch: 000 | Loss: 2313.81982421875


[92mINFO 2025-07-12 22:49:21,675[0m:      aggregate_evaluate: received 2 results and 0 failures
[91mERROR 2025-07-12 22:49:21,677[0m:     ServerApp thread raised an exception: float division by zero
[91mERROR 2025-07-12 22:49:21,819[0m:     Traceback (most recent call last):
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\simulation\run_simulation.py", line 268, in server_th_with_start_checks
    updated_context = _run(
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\run_serverapp.py", line 62, in run
    server_app(grid=grid, context=context)
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\server_app.py", line 166, in __call__
    start_grid(
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\compat\app.py", line 90, in start_grid
    hist = run_fl(
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\server.py", line 492, in run_fl
    hist, elapsed_time = server.fit(
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\

RuntimeError: Exception in ServerApp thread