In [1]:
%cd .

d:\MachineLearning\federated_vae\main


In [2]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

  from .autonotebook import tqdm as notebook_tqdm
2025-07-12 23:21:09,406	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cuda
Flower 1.19.0 / PyTorch 2.5.1+cu121


In [None]:
NUM_CLIENTS = 2
BATCH_SIZE = 256
NUM_ROUNDS = 30

from test_flwr import get_all_vocab, split_data
vocab = get_all_vocab(["../data/20NG"])
datasets = split_data(dir = "../data/20NG", num_split=NUM_CLIENTS, vocab = vocab, batch_size= BATCH_SIZE)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 6586.48it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 8763.46it/s]
loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 6141.97it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 6692.00it/s]


In [4]:
from model.ETM import ETM
from trainer.basic_trainer import BasicTrainer

# net = ETM(len(vocab)).to(DEVICE)

# trainer = BasicTrainer(model = net, dataset = datasets[0], epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
#                        log_interval=10)

# trainer.train()


In [5]:
# res = trainer.get_top_words()
# print(res)

In [6]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [None]:
from data.basic_dataset import RawDataset
class FlowerClient(NumPyClient):
  def __init__(self, net, dataset : RawDataset, id):
    self.net = net
    self.dataset = dataset
    self.trainer = BasicTrainer(net, dataset, epochs = 1, log_interval=10, device = DEVICE)
    self.id = id
    self.save_dir = "model_parameters/"
    self.round_id = 0

  # return the current local model parameters
  def get_parameters(self, config):
    return get_parameters(self.net)

  # receive global parameter, train, return updated model to server
  def fit(self, parameters, config):
    self.round_id += 1
    set_parameters(self.net, parameters)
    self.trainer.train(model_name = f"ETM - Client {self.id}")

    if (self.round_id == NUM_ROUNDS)
      self._save_parameters(self.net.parameters())
    return get_parameters(self.net), len(self.dataset.train_texts), {}

  # receive global parameter, evaluate model from local's data, return the evaluation result
  def evaluate(self, parameters, config):
    set_parameters(self.net, parameters)
    loss, acc = -1, -1
    return float(loss), 1, {"accuracy":float(acc)}
  
  def _save_parameters(self, parameters, prefix="ETM"):
      """Hàm helper để lưu parameters"""
      # Chuyển Flower Parameters sang state_dict
      params_dict = zip(self.net.state_dict().keys(), parameters)
      state_dict = {k: torch.tensor(v) for k, v in params_dict}
        
      # Tạo tên file
      filename = f"{prefix}_client{self.id}.pth"
      save_path = self.save_dir + filename
        
      # Lưu file
      torch.save(state_dict, save_path)
      print(f"Saved client{self.id}'s model at {save_path}")

test = FlowerClient(ETM(len(vocab)), datasets[0], 0)

In [8]:
test._save_parameters(test.net.parameters())

Saved client0's model at model_parameters/ETM_client0.pth


In [9]:
def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = ETM(len(vocab)).to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    dataset = datasets[partition_id]

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return FlowerClient(net, dataset, partition_id).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [10]:
print(test.net.parameters)

<bound method Module.parameters of ETM(
  (encoder1): Sequential(
    (0): Linear(in_features=5000, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.0, inplace=False)
  )
  (fc21): Linear(in_features=800, out_features=50, bias=True)
  (fc22): Linear(in_features=800, out_features=50, bias=True)
)>


In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Configure the server for 30 rounds of training
    config = ServerConfig(num_rounds=NUM_ROUNDS)
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_available_clients=NUM_CLIENTS,
    )
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [12]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client
if DEVICE == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

In [13]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    verbose_logging=True
)

[94mDEBUG 2025-07-12 23:21:26,745[0m:     Asyncio event loop already running.
[94mDEBUG 2025-07-12 23:21:26,746[0m:     Logger propagate set to False
[94mDEBUG 2025-07-12 23:21:26,748[0m:     Pre-registering run with id 1713423872093146814
[94mDEBUG 2025-07-12 23:21:26,750[0m:     Using InMemoryState
[94mDEBUG 2025-07-12 23:21:26,751[0m:     Using InMemoryState
[92mINFO 2025-07-12 23:21:26,760[0m:      Starting Flower ServerApp, config: num_rounds=30, no round_timeout
[94mDEBUG 2025-07-12 23:21:26,764[0m:     Using InMemoryState
[92mINFO 2025-07-12 23:21:26,766[0m:      
[94mDEBUG 2025-07-12 23:21:26,766[0m:     Registered 2 nodes
[94mDEBUG 2025-07-12 23:21:26,767[0m:     Supported backends: ['ray']
[92mINFO 2025-07-12 23:21:26,769[0m:      [INIT]
[94mDEBUG 2025-07-12 23:21:26,770[0m:     Initialising: RayBackend
[92mINFO 2025-07-12 23:21:26,771[0m:      Requesting initial parameters from one random client
[94mDEBUG 2025-07-12 23:21:26,773[0m:     Backend co

[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1954.833251953125
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1900.1961669921875
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:22:00,843[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:00,940[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:22:03,836[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:03,839[0m:      
[92mINFO 2025-07-12 23:22:03,839[0m:      [ROUND 2]
[92mINFO 2025-07-12 23:22:03,840[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1610.037353515625
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1565.8861083984375
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:22:11,258[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:11,327[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:22:13,522[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:13,523[0m:      
[92mINFO 2025-07-12 23:22:13,524[0m:      [ROUND 3]
[92mINFO 2025-07-12 23:22:13,524[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1389.2437744140625
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1349.533203125
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:22:20,822[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:20,893[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:22:24,494[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:24,504[0m:      
[92mINFO 2025-07-12 23:22:24,505[0m:      [ROUND 4]
[92mINFO 2025-07-12 23:22:24,506[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1231.1519775390625
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1200.928466796875
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:22:33,882[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:33,938[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:22:37,289[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:37,290[0m:      
[92mINFO 2025-07-12 23:22:37,291[0m:      [ROUND 5]
[92mINFO 2025-07-12 23:22:37,292[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1124.189208984375
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1095.61474609375
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:22:46,042[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:46,121[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:22:57,317[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:22:58,027[0m:      
[92mINFO 2025-07-12 23:22:58,127[0m:      [ROUND 6]
[92mINFO 2025-07-12 23:22:58,127[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1044.543701171875
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 1019.7268676757812
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:23:18,579[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:23:18,638[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:23:24,006[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:23:24,008[0m:      
[92mINFO 2025-07-12 23:23:24,008[0m:      [ROUND 7]
[92mINFO 2025-07-12 23:23:24,008[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 987.4068603515625
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 963.5377197265625
[36m(ClientAppActor pid=17872)[0m {}


[92mINFO 2025-07-12 23:23:42,044[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-12 23:23:42,120[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-12 23:23:44,858[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-12 23:23:44,861[0m:      
[92mINFO 2025-07-12 23:23:44,862[0m:      [ROUND 8]
[92mINFO 2025-07-12 23:23:44,863[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 0
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 942.1895751953125
[36m(ClientAppActor pid=17872)[0m {}
[36m(ClientAppActor pid=17872)[0m Client's model: ETM - Client 1
[36m(ClientAppActor pid=17872)[0m Epoch: 000 | Loss: 919.9139404296875
[36m(ClientAppActor pid=17872)[0m {}


[33m(raylet)[0m [2025-07-12 23:24:02,763 C 19496 11764] (raylet.exe) dlmalloc.cc:129:  Check failed: *handle != nullptr CreateFileMapping() failed. GetLastError() = 1455
[33m(raylet)[0m *** StackTrace Information ***
[33m(raylet)[0m unknown
[33m(raylet)[0m 
[91mERROR 2025-07-12 23:24:10,501[0m:     An exception was raised when processing a message by RayBackend
[91mERROR 2025-07-12 23:24:10,507[0m:     The actor 4c33155849fd78a47e1ec21501000000 is unavailable: The actor is temporarily unavailable: RpcError: RPC Error message: Connection reset; RPC Error details: . The task may or maynot have been executed on the actor.
[91mERROR 2025-07-12 23:24:10,684[0m:     Traceback (most recent call last):
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\superlink\fleet\vce\vce_api.py", line 112, in worker
    out_mssg, updated_context = backend.process_message(message, context)
  File "d:\Anaconda\envs\TMenv\lib\site-packages\flwr\server\superlink\fleet\vce\backend\rayba