In [1]:
from collections import OrderedDict
from typing import List, Tuple, Union, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr as fl
import numpy as np
import torch
import glob
import os
from flwr.common import ndarrays_to_parameters
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, FitRes, Parameters, Scalar
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {fl.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

KeyboardInterrupt: 

In [None]:
NUM_CLIENTS = 2
BATCH_SIZE = 256
NUM_ROUNDS = 30

from test_flwr import get_all_vocab, split_data
vocab = get_all_vocab(["datasets/20NG"])
datasets = split_data(dir = "datasets/20NG", num_split=NUM_CLIENTS, vocab = vocab, batch_size= BATCH_SIZE)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 8589.03it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 10600.52it/s]
loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 8589.95it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 10606.07it/s]


In [None]:
from model.ETM import ETM
from trainer.basic_trainer import BasicTrainer



In [None]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [None]:
from data.basic_dataset import RawDataset
class FlowerClient(NumPyClient):
  def __init__(self, net, dataset : RawDataset, id):
    self.net = net
    self.dataset = dataset
    self.trainer = BasicTrainer(net, dataset, epochs = 1, log_interval=10, device = DEVICE, save_model = True, save_interval=NUM_ROUNDS)
    self.id = id
    self.save_dir = "model_parameters/"
    self.round_id = 0
    self.total_round = NUM_ROUNDS

  # return the current local model parameters
  def get_parameters(self, config):
    return get_parameters(self.net)

  # receive global parameter, train, return updated model to server
  def fit(self, parameters, config):
    set_parameters(self.net, parameters)
    self.trainer.train(model_name = f"ETM_Client{self.id}")

    return get_parameters(self.net), len(self.dataset.train_texts), {}

  # receive global parameter, evaluate model from local's data, return the evaluation result
  def evaluate(self, parameters, config):
    set_parameters(self.net, parameters)
    loss, acc = -1, -1
    return float(loss), 1, {"accuracy":float(acc)}


test = FlowerClient(ETM(len(vocab)), datasets[0], 0)

In [None]:
def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = ETM(len(vocab)).to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    dataset = datasets[partition_id]

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return FlowerClient(net, dataset, partition_id).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [None]:
class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if aggregated_parameters is not None:
            # Convert `Parameters` to `list[np.ndarray]`
            aggregated_ndarrays: list[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )

            # Save aggregated_ndarrays to disk
            if server_round % 10 == 0:
                print(f"Saving round {server_round} aggregated_ndarrays...")
                np.savez(f"model_parameters/model_round_{server_round}.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics


# Create strategy and pass into ServerApp
def server_fn(context):
    strategy = SaveModelStrategy(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_available_clients=NUM_CLIENTS,
    )
    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)


server = ServerApp(server_fn=server_fn)

In [None]:
def get_latest_server_model(net):
    list_of_files = [fname for fname in glob.glob("model_parameters/model_round_*")]
    latest_round_file = max(list_of_files, key=os.path.getctime)
    print("Loading pre-trained model from: ", latest_round_file)
    
    # Load NumPy arrays from .npz file
    with np.load(latest_round_file) as data:
        arrays = [data[f'arr_{i}'] for i in range(len(data.files))]
    
    # Convert to PyTorch state_dict
    state_dict = {k: torch.from_numpy(v) for k, v in zip(net.state_dict().keys(), arrays)}
    net.load_state_dict(state_dict)
    
    # Convert to Flower Parameters
    state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
    parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)
    return parameters

In [None]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client
if DEVICE == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

In [None]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    verbose_logging=True
)

[94mDEBUG 2025-07-14 18:13:31,990[0m:     Asyncio event loop already running.
[94mDEBUG 2025-07-14 18:13:31,991[0m:     Logger propagate set to False
[94mDEBUG 2025-07-14 18:13:31,991[0m:     Pre-registering run with id 16802060710903922778
[94mDEBUG 2025-07-14 18:13:31,992[0m:     Using InMemoryState
[94mDEBUG 2025-07-14 18:13:31,992[0m:     Using InMemoryState
[92mINFO 2025-07-14 18:13:31,995[0m:      Starting Flower ServerApp, config: num_rounds=30, no round_timeout
[92mINFO 2025-07-14 18:13:32,022[0m:      
[94mDEBUG 2025-07-14 18:13:32,021[0m:     Using InMemoryState
[92mINFO 2025-07-14 18:13:32,024[0m:      [INIT]
[92mINFO 2025-07-14 18:13:32,025[0m:      Requesting initial parameters from one random client
[94mDEBUG 2025-07-14 18:13:32,024[0m:     Registered 2 nodes
[94mDEBUG 2025-07-14 18:13:32,026[0m:     Supported backends: ['ray']
[94mDEBUG 2025-07-14 18:13:32,027[0m:     Initialising: RayBackend
[94mDEBUG 2025-07-14 18:13:32,027[0m:     Backend c

[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1941.7176513671875


[92mINFO 2025-07-14 18:13:55,504[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1894.876220703125


[92mINFO 2025-07-14 18:13:55,570[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:13:57,169[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:13:57,171[0m:      
[92mINFO 2025-07-14 18:13:57,171[0m:      [ROUND 2]
[92mINFO 2025-07-14 18:13:57,172[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1607.691162109375


[92mINFO 2025-07-14 18:13:59,423[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1575.80712890625


[92mINFO 2025-07-14 18:13:59,507[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:14:01,151[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:01,153[0m:      
[92mINFO 2025-07-14 18:14:01,153[0m:      [ROUND 3]
[92mINFO 2025-07-14 18:14:01,154[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1390.7388916015625


[92mINFO 2025-07-14 18:14:03,498[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:03,566[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1354.6754150390625


[92mINFO 2025-07-14 18:14:05,913[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:05,913[0m:      
[92mINFO 2025-07-14 18:14:05,914[0m:      [ROUND 4]
[92mINFO 2025-07-14 18:14:05,914[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1227.1517333984375


[92mINFO 2025-07-14 18:14:08,497[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1197.91015625


[92mINFO 2025-07-14 18:14:08,643[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:14:11,336[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:11,337[0m:      
[92mINFO 2025-07-14 18:14:11,337[0m:      [ROUND 5]
[92mINFO 2025-07-14 18:14:11,338[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1113.5338134765625


[92mINFO 2025-07-14 18:14:14,163[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:14,226[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1088.9791259765625


[92mINFO 2025-07-14 18:14:15,989[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:15,990[0m:      
[92mINFO 2025-07-14 18:14:15,990[0m:      [ROUND 6]
[92mINFO 2025-07-14 18:14:15,991[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1032.0389404296875


[92mINFO 2025-07-14 18:14:18,193[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:18,258[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 1008.400634765625


[92mINFO 2025-07-14 18:14:19,844[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:19,845[0m:      
[92mINFO 2025-07-14 18:14:19,846[0m:      [ROUND 7]
[92mINFO 2025-07-14 18:14:19,847[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 973.5540161132812


[92mINFO 2025-07-14 18:14:22,142[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 950.6398315429688


[92mINFO 2025-07-14 18:14:22,232[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:14:23,833[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:23,834[0m:      
[92mINFO 2025-07-14 18:14:23,835[0m:      [ROUND 8]
[92mINFO 2025-07-14 18:14:23,835[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 928.7109985351562


[92mINFO 2025-07-14 18:14:26,503[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:26,571[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 908.929443359375


[92mINFO 2025-07-14 18:14:28,149[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:28,150[0m:      
[92mINFO 2025-07-14 18:14:28,151[0m:      [ROUND 9]
[92mINFO 2025-07-14 18:14:28,152[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 896.1220703125


[92mINFO 2025-07-14 18:14:30,375[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:30,441[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 877.8868408203125


[92mINFO 2025-07-14 18:14:32,805[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:32,806[0m:      
[92mINFO 2025-07-14 18:14:32,806[0m:      [ROUND 10]
[92mINFO 2025-07-14 18:14:32,807[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 872.2798461914062


[92mINFO 2025-07-14 18:14:36,921[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:37,060[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 854.0929565429688
Saving round 10 aggregated_ndarrays...


[92mINFO 2025-07-14 18:14:39,621[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:39,623[0m:      
[92mINFO 2025-07-14 18:14:39,623[0m:      [ROUND 11]
[92mINFO 2025-07-14 18:14:39,624[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 854.0578002929688


[92mINFO 2025-07-14 18:14:42,448[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:42,521[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 836.7319946289062


[92mINFO 2025-07-14 18:14:44,662[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:44,663[0m:      
[92mINFO 2025-07-14 18:14:44,664[0m:      [ROUND 12]
[92mINFO 2025-07-14 18:14:44,664[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 839.989990234375


[92mINFO 2025-07-14 18:14:47,447[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:47,517[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 822.9850463867188


[92mINFO 2025-07-14 18:14:49,188[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:49,191[0m:      
[92mINFO 2025-07-14 18:14:49,193[0m:      [ROUND 13]
[92mINFO 2025-07-14 18:14:49,196[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 829.7274780273438


[92mINFO 2025-07-14 18:14:51,638[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 812.7026977539062


[92mINFO 2025-07-14 18:14:51,726[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:14:53,486[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:53,490[0m:      
[92mINFO 2025-07-14 18:14:53,494[0m:      [ROUND 14]
[92mINFO 2025-07-14 18:14:53,495[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 820.6788940429688


[92mINFO 2025-07-14 18:14:56,051[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 804.0509033203125


[92mINFO 2025-07-14 18:14:56,141[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:14:57,810[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:14:57,813[0m:      
[92mINFO 2025-07-14 18:14:57,816[0m:      [ROUND 15]
[92mINFO 2025-07-14 18:14:57,820[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 813.77734375


[92mINFO 2025-07-14 18:15:00,075[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:00,148[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 797.7197265625


[92mINFO 2025-07-14 18:15:02,157[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:02,158[0m:      
[92mINFO 2025-07-14 18:15:02,159[0m:      [ROUND 16]
[92mINFO 2025-07-14 18:15:02,160[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 808.0672607421875


[92mINFO 2025-07-14 18:15:05,360[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:05,444[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 791.8588256835938


[92mINFO 2025-07-14 18:15:07,228[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:07,230[0m:      
[92mINFO 2025-07-14 18:15:07,232[0m:      [ROUND 17]
[92mINFO 2025-07-14 18:15:07,234[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 803.269775390625


[92mINFO 2025-07-14 18:15:09,410[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:09,471[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 787.1657104492188


[92mINFO 2025-07-14 18:15:11,186[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:11,188[0m:      
[92mINFO 2025-07-14 18:15:11,190[0m:      [ROUND 18]
[92mINFO 2025-07-14 18:15:11,191[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 799.2686157226562


[92mINFO 2025-07-14 18:15:13,328[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:13,413[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 782.9991455078125


[92mINFO 2025-07-14 18:15:15,116[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:15,118[0m:      
[92mINFO 2025-07-14 18:15:15,121[0m:      [ROUND 19]
[92mINFO 2025-07-14 18:15:15,125[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 795.6741943359375


[92mINFO 2025-07-14 18:15:17,433[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 779.8207397460938


[92mINFO 2025-07-14 18:15:17,521[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:15:19,255[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:19,256[0m:      
[92mINFO 2025-07-14 18:15:19,257[0m:      [ROUND 20]
[92mINFO 2025-07-14 18:15:19,259[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 792.4217529296875


[92mINFO 2025-07-14 18:15:21,427[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:21,527[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 776.630615234375
Saving round 20 aggregated_ndarrays...


[92mINFO 2025-07-14 18:15:23,882[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:23,883[0m:      
[92mINFO 2025-07-14 18:15:23,884[0m:      [ROUND 21]
[92mINFO 2025-07-14 18:15:23,885[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 789.7561645507812


[92mINFO 2025-07-14 18:15:26,916[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:26,985[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 773.943603515625
[36m(ClientAppActor pid=13816)[0m 


[92mINFO 2025-07-14 18:15:28,754[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:28,758[0m:      
[92mINFO 2025-07-14 18:15:28,761[0m:      [ROUND 22]
[92mINFO 2025-07-14 18:15:28,764[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 787.2626953125


[92mINFO 2025-07-14 18:15:31,136[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:31,216[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 771.6884155273438


[92mINFO 2025-07-14 18:15:33,007[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:33,009[0m:      
[92mINFO 2025-07-14 18:15:33,011[0m:      [ROUND 23]
[92mINFO 2025-07-14 18:15:33,013[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 785.361083984375


[92mINFO 2025-07-14 18:15:35,386[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:35,460[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 769.3698120117188


[92mINFO 2025-07-14 18:15:37,392[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:37,393[0m:      
[92mINFO 2025-07-14 18:15:37,395[0m:      [ROUND 24]
[92mINFO 2025-07-14 18:15:37,397[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 783.0977783203125


[92mINFO 2025-07-14 18:15:39,570[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:39,639[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 767.515625


[92mINFO 2025-07-14 18:15:41,312[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:41,314[0m:      
[92mINFO 2025-07-14 18:15:41,318[0m:      [ROUND 25]
[92mINFO 2025-07-14 18:15:41,321[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 781.4594116210938


[92mINFO 2025-07-14 18:15:43,858[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 765.8432006835938


[92mINFO 2025-07-14 18:15:43,986[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:15:45,927[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:45,929[0m:      
[92mINFO 2025-07-14 18:15:45,931[0m:      [ROUND 26]
[92mINFO 2025-07-14 18:15:45,932[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 780.1004028320312


[92mINFO 2025-07-14 18:15:48,333[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 764.081298828125


[92mINFO 2025-07-14 18:15:48,407[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:15:50,256[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:50,258[0m:      
[92mINFO 2025-07-14 18:15:50,260[0m:      [ROUND 27]
[92mINFO 2025-07-14 18:15:50,262[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 778.6409912109375


[92mINFO 2025-07-14 18:15:52,634[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 762.9615478515625


[92mINFO 2025-07-14 18:15:52,712[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:15:54,857[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:54,858[0m:      
[92mINFO 2025-07-14 18:15:54,859[0m:      [ROUND 28]
[92mINFO 2025-07-14 18:15:54,860[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 777.0335693359375


[92mINFO 2025-07-14 18:15:58,379[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-14 18:15:58,509[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 761.2867431640625


[92mINFO 2025-07-14 18:16:00,435[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:16:00,436[0m:      
[92mINFO 2025-07-14 18:16:00,437[0m:      [ROUND 29]
[92mINFO 2025-07-14 18:16:00,438[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 775.8414306640625


[92mINFO 2025-07-14 18:16:03,839[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 760.3041381835938


[92mINFO 2025-07-14 18:16:03,955[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-14 18:16:06,243[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:16:06,244[0m:      
[92mINFO 2025-07-14 18:16:06,245[0m:      [ROUND 30]
[92mINFO 2025-07-14 18:16:06,245[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 774.8886108398438


[92mINFO 2025-07-14 18:16:09,586[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=13816)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=13816)[0m Epoch: 000 | Loss: 759.038818359375


[92mINFO 2025-07-14 18:16:09,721[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 30 aggregated_ndarrays...


[92mINFO 2025-07-14 18:16:11,746[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-14 18:16:11,763[0m:      
[92mINFO 2025-07-14 18:16:11,763[0m:      [SUMMARY]
[92mINFO 2025-07-14 18:16:11,764[0m:      Run finished 30 round(s) in 140.08s
[92mINFO 2025-07-14 18:16:11,765[0m:      	History (loss, distributed):
[92mINFO 2025-07-14 18:16:11,766[0m:      		round 1: -1.0
[92mINFO 2025-07-14 18:16:11,766[0m:      		round 2: -1.0
[92mINFO 2025-07-14 18:16:11,767[0m:      		round 3: -1.0
[92mINFO 2025-07-14 18:16:11,767[0m:      		round 4: -1.0
[92mINFO 2025-07-14 18:16:11,768[0m:      		round 5: -1.0
[92mINFO 2025-07-14 18:16:11,768[0m:      		round 6: -1.0
[92mINFO 2025-07-14 18:16:11,768[0m:      		round 7: -1.0
[92mINFO 2025-07-14 18:16:11,769[0m:      		round 8: -1.0
[92mINFO 2025-07-14 18:16:11,769[0m:      		round 9: -1.0
[92mINFO 2025-07-14 18:16:11,770[0m:      		round 10: -1.0
[92mINFO 2025-07-14 18:16:11,770[0m:      		ro

In [None]:
from utils._utils import get_top_words
net = ETM(len(vocab))
test = get_latest_server_model(net)
beta = net.get_beta().detach().cpu().numpy()
topwords = get_top_words(beta, vocab, 15, verbose=True)

Loading pre-trained model from:  model_parameters\model_round_20.npz
Topic 0: mac later keys started pittsburgh date health dod crime previous development engineering involved follow media
Topic 1: state technology clipper remember important comes network looks study usually ideas york religious second sounds
Topic 2: data end interested went california reference genocide laboratory town visual rangers prevent turbo florida closer
Topic 3: second let away ask similar series note error press ground player food safety generally contains
Topic 4: cost machine future problems section speak higher device launch response escrow modern army received performance
Topic 5: lot research phone trying paul april states force bus religion atheists common condition reserve easily
Topic 6: version getting value bike asked cards rules scott died audio completely hey truth listen shown
Topic 7: think way space law access live known ones care policy chris total education issue ken
Topic 8: car image run 

In [None]:
trainer = BasicTrainer(net, datasets[0])

In [None]:
net.to(DEVICE)

ETM(
  (encoder1): Sequential(
    (0): Linear(in_features=5000, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.0, inplace=False)
  )
  (fc21): Linear(in_features=800, out_features=50, bias=True)
  (fc22): Linear(in_features=800, out_features=50, bias=True)
)

In [None]:
########################### test new documents ####################################
from data.preprocess import Preprocess

preprocess = Preprocess()

new_docs = [
    "This is a new document about space, including words like space, satellite, launch, orbit.",
    "This is a new document about Microsoft Windows, including words like windows, files, dos."
]

parsed_new_docs, new_bow = preprocess.parse(new_docs, vocab)
print(new_bow.shape)

print(new_bow.toarray())
input = torch.as_tensor(new_bow.toarray(), device="cuda").float()
print(input)
new_theta = trainer.test(input)

print(new_theta.argmax(1))
for x in new_theta.argmax(1):
    print(topwords[x])

parsing texts: 100%|██████████| 2/2 [00:00<00:00, 1971.93it/s]


(2, 5000)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
[ 7 15]
think way space law access live known ones care policy chris total education issue ken
god windows read bible human war current code distribution example division open offer needed knows
