In [84]:
%cd .

d:\MachineLearning\federated_vae\main


In [85]:
from collections import OrderedDict
from typing import List, Tuple, Union, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr as fl
import numpy as np
import torch
import glob
import os
from flwr.common import ndarrays_to_parameters
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, FitRes, Parameters, Scalar
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

# Use the GPU when available; this string is passed to the local trainer and
# compared against "cuda" when sizing the simulation's per-client resources.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {fl.__version__} / PyTorch {torch.__version__}")
# Silence the progress bars of the `datasets` library.
disable_progress_bar()

Training on cuda
Flower 1.19.0 / PyTorch 2.5.1+cu121


In [86]:
# Experiment configuration shared by the client, strategy and simulation cells below.
NUM_CLIENTS = 2
BATCH_SIZE = 256
NUM_ROUNDS = 30

from test_flwr import get_all_vocab, split_data
# One shared vocabulary (presumably built over the whole ../data/20NG corpus) so
# every client's inputs use the same index space; ETM is built with len(vocab) inputs.
vocab = get_all_vocab(["../data/20NG"])
# One dataset per simulated client; client_fn indexes this list by partition id.
datasets = split_data(dir = "../data/20NG", num_split=NUM_CLIENTS, vocab = vocab, batch_size= BATCH_SIZE)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 8107.73it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 9776.49it/s] 
loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 8314.55it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 10730.31it/s]


In [None]:
from model.ETM import ETM
from trainer.basic_trainer import BasicTrainer



In [89]:
def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net``, matched to ``state_dict`` key order.

    Args:
        net: torch module whose ``state_dict()`` key order matches ``parameters``
            (the order produced by ``get_parameters``).
        parameters: one array per ``state_dict`` entry.

    Raises:
        ValueError: if the number of arrays differs from the number of
            ``state_dict`` entries (``zip`` would otherwise silently drop extras).
        RuntimeError: from ``load_state_dict`` on a shape mismatch.
    """
    keys = list(net.state_dict().keys())
    if len(keys) != len(parameters):
        raise ValueError(
            f"Expected {len(keys)} arrays for the model state_dict, got {len(parameters)}"
        )
    # torch.tensor copies each array and keeps its dtype. The previous
    # torch.Tensor(...) coerced everything to float32, which corrupts integer
    # buffers (e.g. step counters) whose values exceed float32 precision.
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    """Snapshot ``net``'s ``state_dict`` as NumPy arrays, in key order (moved to CPU first)."""
    arrays: List[np.ndarray] = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

In [90]:
from data.basic_dataset import RawDataset
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains an ETM topic model on one data partition.

    ``fit`` loads the server's global weights into ``self.net``, runs the local
    ``BasicTrainer`` (configured for 1 epoch) and sends the updated weights back.
    ``evaluate`` is still a placeholder that reports dummy metrics.
    """

    def __init__(self, net, dataset: RawDataset, id):
        self.net = net
        self.dataset = dataset
        # Local trainer: one epoch on this client's partition per federated round.
        self.trainer = BasicTrainer(
            net,
            dataset,
            epochs=1,
            log_interval=10,
            device=DEVICE,
            save_model=True,
            save_interval=NUM_ROUNDS,
        )
        self.id = id
        # NOTE(review): these three attributes are not read in the visible notebook code.
        self.save_dir = "model_parameters/"
        self.round_id = 0
        self.total_round = NUM_ROUNDS

    def get_parameters(self, config):
        """Return the current local weights; ``config`` is ignored."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Train locally starting from the global weights.

        Returns the updated weights, the number of local training texts
        (reported to the server as this client's example count) and an empty
        metrics dict.
        """
        set_parameters(self.net, parameters)
        self.trainer.train(model_name=f"ETM_Client{self.id}")
        updated_weights = get_parameters(self.net)
        num_examples = len(self.dataset.train_texts)
        return updated_weights, num_examples, {}

    def evaluate(self, parameters, config):
        """Load the global weights and return placeholder metrics.

        TODO: no real evaluation yet. Loss and accuracy are reported as -1.0
        with an example count of 1.
        """
        set_parameters(self.net, parameters)
        placeholder = float(-1)
        return placeholder, 1, {"accuracy": placeholder}


# Smoke test: build one client on partition 0 to check that construction works.
# NOTE(review): `test` is unused here and is reassigned in a later cell.
test = FlowerClient(ETM(len(vocab)), datasets[0], 0)

In [91]:
def client_fn(context: Context) -> Client:
    """Create the Flower client for the simulated node described by ``context``.

    Each node gets a fresh ETM on ``DEVICE`` and the 20NG data partition selected
    by ``context.node_config["partition-id"]``, so every client trains on its own
    slice of the corpus.

    Raises:
        ValueError: if the partition id does not index into ``datasets``.
    """
    # Resolve this node's data partition (20 Newsgroups split from split_data).
    # Cast defensively: node_config values are generic scalars.
    partition_id = int(context.node_config["partition-id"])
    # Reject out-of-range ids explicitly. A negative id would otherwise wrap
    # around and silently give two nodes the same partition.
    if not 0 <= partition_id < len(datasets):
        raise ValueError(
            f"partition-id {partition_id} out of range for {len(datasets)} data partitions"
        )
    dataset = datasets[partition_id]

    # Load model
    net = ETM(len(vocab)).to(DEVICE)

    # FlowerClient is a NumPyClient; .to_client() wraps it as a flwr.client.Client.
    return FlowerClient(net, dataset, partition_id).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [None]:
class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg that also checkpoints the aggregated global weights to disk.

    After a successful aggregation, the weights are written to
    ``<save_dir>/model_round_<round>.npz`` as positional arrays ``arr_0..arr_N``.
    A checkpoint is written every ``save_every`` rounds. When ``num_rounds`` is
    given, the final round is always saved too, so the last global model is not
    lost when ``num_rounds`` is not a multiple of ``save_every``.

    Args:
        save_dir: checkpoint directory; created if missing.
        save_every: checkpoint period in rounds (must be >= 1).
        num_rounds: total number of rounds, used to force a final checkpoint.
        **kwargs: forwarded unchanged to ``FedAvg``.
    """

    def __init__(
        self,
        *,
        save_dir: str = "model_parameters",
        save_every: int = 10,
        num_rounds: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if save_every < 1:
            raise ValueError(f"save_every must be >= 1, got {save_every}")
        self.save_dir = save_dir
        self.save_every = save_every
        self.num_rounds = num_rounds

    def _should_save(self, server_round: int) -> bool:
        """Return True when ``server_round`` is a checkpoint round."""
        is_periodic = server_round % self.save_every == 0
        is_final = self.num_rounds is not None and server_round == self.num_rounds
        return is_periodic or is_final

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate with FedAvg, then checkpoint the result if this round is due."""
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        # The base class may return None parameters (e.g. nothing to aggregate).
        if aggregated_parameters is not None and self._should_save(server_round):
            aggregated_ndarrays: list[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )
            # np.savez fails if the target directory does not exist yet.
            os.makedirs(self.save_dir, exist_ok=True)
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(
                os.path.join(self.save_dir, f"model_round_{server_round}.npz"),
                *aggregated_ndarrays,
            )

        return aggregated_parameters, aggregated_metrics


# Create strategy and pass into ServerApp
def server_fn(context):
    """Build the server components: checkpointing FedAvg run for NUM_ROUNDS rounds."""
    strategy = SaveModelStrategy(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_available_clients=NUM_CLIENTS,
        num_rounds=NUM_ROUNDS,
    )
    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)


server = ServerApp(server_fn=server_fn)

In [94]:
def get_latest_server_model(net, save_dir: str = "model_parameters"):
    """Load the most recently written server checkpoint into ``net``.

    Looks for ``model_round_*.npz`` files (as written by ``SaveModelStrategy``)
    in ``save_dir`` and picks the one modified last. Modification time is used
    instead of ``os.path.getctime``: on Windows ``getctime`` is the *creation*
    time. That time is not updated when a later run overwrites an existing
    checkpoint, so a stale file from an earlier run could be selected.

    Args:
        net: torch module that receives the weights (modified in place).
        save_dir: directory containing the checkpoints.

    Returns:
        The loaded weights as Flower ``Parameters``, in ``state_dict`` order.

    Raises:
        FileNotFoundError: if ``save_dir`` holds no checkpoint.
        ValueError: if the checkpoint's array count does not match ``net``.
    """
    list_of_files = glob.glob(os.path.join(save_dir, "model_round_*.npz"))
    if not list_of_files:
        raise FileNotFoundError(f"No model_round_*.npz checkpoints found in {save_dir!r}")
    latest_round_file = max(list_of_files, key=os.path.getmtime)
    print("Loading pre-trained model from: ", latest_round_file)

    # np.savez stores positional arrays as arr_0, arr_1, ...; read them back in order.
    with np.load(latest_round_file) as data:
        arrays = [data[f"arr_{i}"] for i in range(len(data.files))]

    keys = list(net.state_dict().keys())
    if len(keys) != len(arrays):
        raise ValueError(
            f"{latest_round_file} holds {len(arrays)} arrays, model expects {len(keys)}"
        )

    # Convert to PyTorch state_dict and load it (strict: keys/shapes must match).
    state_dict = {k: torch.from_numpy(v) for k, v in zip(keys, arrays)}
    net.load_state_dict(state_dict)

    # Re-export from the model so the returned Parameters carry net's own dtypes.
    state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
    return fl.common.ndarrays_to_parameters(state_dict_ndarrays)

In [95]:
# Per-client resources for the Ray simulation backend: always 1 CPU, plus a
# whole GPU per client when running on CUDA (no GPU otherwise).
# NOTE(review): if the machine has a single GPU, num_gpus=1.0 lets only one
# client actor run at a time (the logs below show every client in the same
# actor pid). A fractional value would let clients share the GPU.
# See the Flower simulation docs for details on `backend_config`.
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if DEVICE == "cuda" else 0.0,
    }
}

In [96]:
# Run simulation: NUM_CLIENTS virtual nodes execute the ClientApp while the
# ServerApp drives NUM_ROUNDS rounds of SaveModelStrategy (FedAvg) aggregation.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    verbose_logging=True
)

[94mDEBUG 2025-07-13 00:28:06,339[0m:     Asyncio event loop already running.


[94mDEBUG 2025-07-13 00:28:06,362[0m:     Logger propagate set to False
[94mDEBUG 2025-07-13 00:28:06,363[0m:     Pre-registering run with id 3275487587500647736
[94mDEBUG 2025-07-13 00:28:06,365[0m:     Using InMemoryState
[94mDEBUG 2025-07-13 00:28:06,366[0m:     Using InMemoryState
[92mINFO 2025-07-13 00:28:06,369[0m:      Starting Flower ServerApp, config: num_rounds=30, no round_timeout
[92mINFO 2025-07-13 00:28:06,370[0m:      
[94mDEBUG 2025-07-13 00:28:06,371[0m:     Using InMemoryState
[94mDEBUG 2025-07-13 00:28:06,373[0m:     Registered 2 nodes
[92mINFO 2025-07-13 00:28:06,374[0m:      [INIT]
[94mDEBUG 2025-07-13 00:28:06,375[0m:     Supported backends: ['ray']
[92mINFO 2025-07-13 00:28:06,375[0m:      Requesting initial parameters from one random client
[94mDEBUG 2025-07-13 00:28:06,376[0m:     Initialising: RayBackend
[94mDEBUG 2025-07-13 00:28:06,378[0m:     Backend config: {'client_resources': {'num_cpus': 1, 'num_gpus': 1.0}, 'init_args': {}, 'a

[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1957.7569580078125


[92mINFO 2025-07-13 00:28:35,754[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:35,877[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1905.6146240234375
Saving round 1 aggregated_ndarrays...


[92mINFO 2025-07-13 00:28:38,031[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:38,032[0m:      
[92mINFO 2025-07-13 00:28:38,034[0m:      [ROUND 2]
[92mINFO 2025-07-13 00:28:38,035[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1606.6868896484375


[92mINFO 2025-07-13 00:28:40,698[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1564.1192626953125


[92mINFO 2025-07-13 00:28:40,826[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 2 aggregated_ndarrays...


[92mINFO 2025-07-13 00:28:43,706[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:43,706[0m:      
[92mINFO 2025-07-13 00:28:43,707[0m:      [ROUND 3]
[92mINFO 2025-07-13 00:28:43,708[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1377.3734130859375


[92mINFO 2025-07-13 00:28:46,419[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1341.1683349609375


[92mINFO 2025-07-13 00:28:46,548[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 3 aggregated_ndarrays...


[92mINFO 2025-07-13 00:28:48,580[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:48,582[0m:      
[92mINFO 2025-07-13 00:28:48,582[0m:      [ROUND 4]
[92mINFO 2025-07-13 00:28:48,584[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1219.7779541015625


[92mINFO 2025-07-13 00:28:51,035[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:51,132[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1194.0657958984375
Saving round 4 aggregated_ndarrays...


[92mINFO 2025-07-13 00:28:53,030[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:53,031[0m:      
[92mINFO 2025-07-13 00:28:53,031[0m:      [ROUND 5]
[92mINFO 2025-07-13 00:28:53,033[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1110.7203369140625


[92mINFO 2025-07-13 00:28:55,496[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1085.0416259765625


[92mINFO 2025-07-13 00:28:55,598[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 5 aggregated_ndarrays...


[92mINFO 2025-07-13 00:28:57,650[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:28:57,651[0m:      
[92mINFO 2025-07-13 00:28:57,651[0m:      [ROUND 6]
[92mINFO 2025-07-13 00:28:57,652[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1030.103271484375


[92mINFO 2025-07-13 00:29:00,084[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 1006.30419921875
Saving round 6 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:00,191[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:29:02,401[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:02,402[0m:      
[92mINFO 2025-07-13 00:29:02,403[0m:      [ROUND 7]
[92mINFO 2025-07-13 00:29:02,404[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 970.5216674804688


[92mINFO 2025-07-13 00:29:04,809[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 949.8295288085938
Saving round 7 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:04,904[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:29:06,750[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:06,752[0m:      
[92mINFO 2025-07-13 00:29:06,753[0m:      [ROUND 8]
[92mINFO 2025-07-13 00:29:06,754[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 928.3550415039062


[92mINFO 2025-07-13 00:29:09,263[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:09,364[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 908.63623046875
Saving round 8 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:11,798[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:11,798[0m:      
[92mINFO 2025-07-13 00:29:11,800[0m:      [ROUND 9]
[92mINFO 2025-07-13 00:29:11,801[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 897.1045532226562


[92mINFO 2025-07-13 00:29:14,627[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:14,728[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 878.5476684570312
Saving round 9 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:16,704[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:16,705[0m:      
[92mINFO 2025-07-13 00:29:16,706[0m:      [ROUND 10]
[92mINFO 2025-07-13 00:29:16,706[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 873.4866333007812


[92mINFO 2025-07-13 00:29:19,108[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:19,208[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 856.02392578125
Saving round 10 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:21,382[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:21,383[0m:      
[92mINFO 2025-07-13 00:29:21,384[0m:      [ROUND 11]
[92mINFO 2025-07-13 00:29:21,385[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 855.5550537109375


[92mINFO 2025-07-13 00:29:23,794[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:23,894[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 838.1940307617188
Saving round 11 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:25,870[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:25,871[0m:      
[92mINFO 2025-07-13 00:29:25,872[0m:      [ROUND 12]
[92mINFO 2025-07-13 00:29:25,873[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 841.649658203125


[92mINFO 2025-07-13 00:29:28,614[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 824.5891723632812


[92mINFO 2025-07-13 00:29:28,749[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 12 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:31,010[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:31,011[0m:      
[92mINFO 2025-07-13 00:29:31,012[0m:      [ROUND 13]
[92mINFO 2025-07-13 00:29:31,013[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 830.05615234375


[92mINFO 2025-07-13 00:29:34,096[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 813.4004516601562
Saving round 13 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:34,196[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:29:36,346[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:36,347[0m:      
[92mINFO 2025-07-13 00:29:36,347[0m:      [ROUND 14]
[92mINFO 2025-07-13 00:29:36,348[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 821.3988037109375


[92mINFO 2025-07-13 00:29:38,759[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:38,858[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 804.63525390625
Saving round 14 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:40,833[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:40,834[0m:      
[92mINFO 2025-07-13 00:29:40,835[0m:      [ROUND 15]
[92mINFO 2025-07-13 00:29:40,836[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 814.098876953125
[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 797.4954833984375


[92mINFO 2025-07-13 00:29:44,160[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:44,270[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 15 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:46,385[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:46,386[0m:      
[92mINFO 2025-07-13 00:29:46,386[0m:      [ROUND 16]
[92mINFO 2025-07-13 00:29:46,387[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 807.8772583007812


[92mINFO 2025-07-13 00:29:48,978[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 791.816162109375
Saving round 16 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:49,089[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:29:51,234[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:51,234[0m:      
[92mINFO 2025-07-13 00:29:51,235[0m:      [ROUND 17]
[92mINFO 2025-07-13 00:29:51,235[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 803.079833984375


[92mINFO 2025-07-13 00:29:53,790[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:53,889[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 786.9244995117188
Saving round 17 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:56,260[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:29:56,262[0m:      
[92mINFO 2025-07-13 00:29:56,264[0m:      [ROUND 18]
[92mINFO 2025-07-13 00:29:56,265[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 799.2915649414062


[92mINFO 2025-07-13 00:29:59,814[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 783.129638671875
Saving round 18 aggregated_ndarrays...


[92mINFO 2025-07-13 00:29:59,930[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:02,927[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:02,928[0m:      
[92mINFO 2025-07-13 00:30:02,929[0m:      [ROUND 19]
[92mINFO 2025-07-13 00:30:02,930[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 795.3217163085938


[92mINFO 2025-07-13 00:30:05,530[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:05,621[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 779.5692749023438
Saving round 19 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:08,206[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:08,207[0m:      
[92mINFO 2025-07-13 00:30:08,207[0m:      [ROUND 20]
[92mINFO 2025-07-13 00:30:08,208[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 792.4185180664062


[92mINFO 2025-07-13 00:30:11,170[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 776.4039916992188
Saving round 20 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:12,174[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:15,621[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:15,623[0m:      
[92mINFO 2025-07-13 00:30:15,624[0m:      [ROUND 21]
[92mINFO 2025-07-13 00:30:15,626[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 789.4107055664062


[92mINFO 2025-07-13 00:30:18,578[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 773.82763671875
Saving round 21 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:18,687[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:20,949[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:20,950[0m:      
[92mINFO 2025-07-13 00:30:20,951[0m:      [ROUND 22]
[92mINFO 2025-07-13 00:30:20,951[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 786.73486328125


[92mINFO 2025-07-13 00:30:23,625[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 771.3928833007812
Saving round 22 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:23,718[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:25,782[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:25,782[0m:      
[92mINFO 2025-07-13 00:30:25,783[0m:      [ROUND 23]
[92mINFO 2025-07-13 00:30:25,784[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 785.0015869140625


[92mINFO 2025-07-13 00:30:28,481[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 769.10302734375
Saving round 23 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:28,631[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:32,403[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:32,404[0m:      
[92mINFO 2025-07-13 00:30:32,405[0m:      [ROUND 24]
[92mINFO 2025-07-13 00:30:32,406[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 782.7913208007812


[92mINFO 2025-07-13 00:30:35,857[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 766.837890625
Saving round 24 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:35,992[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:38,725[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:38,726[0m:      
[92mINFO 2025-07-13 00:30:38,727[0m:      [ROUND 25]
[92mINFO 2025-07-13 00:30:38,728[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 781.1022338867188


[92mINFO 2025-07-13 00:30:41,838[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 765.5335693359375
Saving round 25 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:41,946[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:44,282[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:44,283[0m:      
[92mINFO 2025-07-13 00:30:44,283[0m:      [ROUND 26]
[92mINFO 2025-07-13 00:30:44,284[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 779.3370971679688


[92mINFO 2025-07-13 00:30:46,813[0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:46,911[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 763.9505004882812
Saving round 26 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:48,796[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:48,797[0m:      
[92mINFO 2025-07-13 00:30:48,797[0m:      [ROUND 27]
[92mINFO 2025-07-13 00:30:48,799[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 777.7408447265625


[92mINFO 2025-07-13 00:30:51,356[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 762.429443359375
Saving round 27 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:51,498[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:30:53,610[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:53,612[0m:      
[92mINFO 2025-07-13 00:30:53,612[0m:      [ROUND 28]
[92mINFO 2025-07-13 00:30:53,613[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 776.7052612304688


[92mINFO 2025-07-13 00:30:56,043[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 761.0396118164062


[92mINFO 2025-07-13 00:30:56,149[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 28 aggregated_ndarrays...


[92mINFO 2025-07-13 00:30:58,154[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:30:58,155[0m:      
[92mINFO 2025-07-13 00:30:58,156[0m:      [ROUND 29]
[92mINFO 2025-07-13 00:30:58,157[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 775.1154174804688


[92mINFO 2025-07-13 00:31:00,603[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 760.0211181640625
Saving round 29 aggregated_ndarrays...


[92mINFO 2025-07-13 00:31:00,702[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO 2025-07-13 00:31:02,917[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:31:02,918[0m:      
[92mINFO 2025-07-13 00:31:02,918[0m:      [ROUND 30]
[92mINFO 2025-07-13 00:31:02,920[0m:      configure_fit: strategy sampled 2 clients (out of 2)


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client0
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 774.0159912109375


[92mINFO 2025-07-13 00:31:05,500[0m:      aggregate_fit: received 2 results and 0 failures


[36m(ClientAppActor pid=21468)[0m Client's model: ETM_Client1
[36m(ClientAppActor pid=21468)[0m Epoch: 000 | Loss: 758.583984375


[92mINFO 2025-07-13 00:31:05,617[0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


Saving round 30 aggregated_ndarrays...


[92mINFO 2025-07-13 00:31:07,745[0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO 2025-07-13 00:31:07,750[0m:      
[92mINFO 2025-07-13 00:31:07,751[0m:      [SUMMARY]
[92mINFO 2025-07-13 00:31:07,752[0m:      Run finished 30 round(s) in 156.70s
[92mINFO 2025-07-13 00:31:07,753[0m:      	History (loss, distributed):
[92mINFO 2025-07-13 00:31:07,754[0m:      		round 1: -1.0
[92mINFO 2025-07-13 00:31:07,754[0m:      		round 2: -1.0
[92mINFO 2025-07-13 00:31:07,755[0m:      		round 3: -1.0
[92mINFO 2025-07-13 00:31:07,756[0m:      		round 4: -1.0
[92mINFO 2025-07-13 00:31:07,756[0m:      		round 5: -1.0
[92mINFO 2025-07-13 00:31:07,758[0m:      		round 6: -1.0
[92mINFO 2025-07-13 00:31:07,759[0m:      		round 7: -1.0
[92mINFO 2025-07-13 00:31:07,759[0m:      		round 8: -1.0
[92mINFO 2025-07-13 00:31:07,760[0m:      		round 9: -1.0
[92mINFO 2025-07-13 00:31:07,760[0m:      		round 10: -1.0
[92mINFO 2025-07-13 00:31:07,760[0m:      		ro

In [97]:
from utils._utils import get_top_words

# Rebuild the ETM over the shared vocabulary, then load the most recent saved
# server round into it (the helper prints which .npz file it reads).
# NOTE(review): presumably the weights are loaded into `net` in place; the
# return value is kept as `test` but is not used in this section.
net = ETM(len(vocab))
test = get_latest_server_model(net)

# Pull the topic-word weights out of the graph as a NumPy array, then list the
# 15 highest-weighted words of every topic (verbose=True prints each topic).
beta = (
    net.get_beta()
    .detach()
    .cpu()
    .numpy()
)
topwords = get_top_words(beta, vocab, 15, verbose=True)

Loading pre-trained model from:  model_parameters\model_round_30.npz
Topic 0: sure stop necessary recently write research child event users provided wife russian names closed launch
Topic 1: article trying thought opinions book source certainly claim pub wrote major tim moon electronic comment
Topic 2: want people actually run looking memory faq difference short red crime keith political images mentioned
Topic 3: bit car war questions reason hit fast device deleted chance thank turn jon freedom unix
Topic 4: problems example makes ask school simple ibm form air north bbs oil necessarily answer space
Topic 5: university lines article great general mark getting large days view company common bob record opinion
Topic 6: new far phone experience strong save motif let created printer attack appropriate came copy language
Topic 7: possible game place sun keywords ftp sense self built body worth computing anybody pro port
Topic 8: email type current considered happened wide member stupid obje

In [None]:
# Trainer wrapping the loaded model, bound to the first client's data split
# (datasets[0]); in this section it is only used for inference via trainer.test.
trainer = BasicTrainer(
    net,
    datasets[0],
)

In [107]:
# nn.Module.to moves the parameters in place and returns the same module, so
# rebinding `net` changes nothing; the trailing bare `net` renders the architecture.
net = net.to(DEVICE)
net

ETM(
  (encoder1): Sequential(
    (0): Linear(in_features=5000, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.0, inplace=False)
  )
  (fc21): Linear(in_features=800, out_features=50, bias=True)
  (fc22): Linear(in_features=800, out_features=50, bias=True)
)

In [109]:
########################### test new documents ####################################
# Infer topic proportions for unseen documents with the loaded model and print
# the top words of each document's highest-weighted topic.
from data.preprocess import Preprocess

preprocess = Preprocess()

new_docs = [
    "This is a new document about space, including words like space, satellite, launch, orbit.",
    "This is a new document about Microsoft Windows, including words like windows, files, dos."
]

# Map the raw documents onto the training vocabulary so the bag-of-words
# columns line up with the model's 5000-dim input layer.
parsed_new_docs, new_bow = preprocess.parse(new_docs, vocab)
new_bow_dense = new_bow.toarray()  # densify once and reuse below
print(new_bow.shape)
print(new_bow_dense)

# Build the input on the configured DEVICE rather than a hard-coded "cuda", so
# the cell also runs on CPU-only machines and matches `net.to(DEVICE)` above.
# The name `bow_input` avoids shadowing the builtin `input`.
bow_input = torch.as_tensor(new_bow_dense, device=DEVICE).float()
print(bow_input)
new_theta = trainer.test(bow_input)

# Index of the highest-weighted topic for each document (computed once).
# NOTE(review): assumes new_theta is (num_docs, num_topics) — confirm against BasicTrainer.test.
dominant_topics = new_theta.argmax(1)
print(dominant_topics)
for topic_idx in dominant_topics:
    print(topwords[topic_idx])

parsing texts: 100%|██████████| 2/2 [00:00<?, ?it/s]

(2, 5000)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
[39 12]
people years help things space order armenian care feel known police books needs previous ideas
god windows real yes department apr written algorithm generally display especially understanding research practice cpu



