In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import flwr as fl
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from datasets import load_dataset
from flwr.client.mod import fixedclipping_mod
from flwr.server.strategy import (
    DifferentialPrivacyClientSideFixedClipping
)

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.utils import * 

In [3]:
# Load the experiment configuration by name and print it, so the settings used
# for this run are recorded in the notebook output.
CONFIG_NAME = "federated_full"
cfg = get_config(CONFIG_NAME)
print_config(cfg)

dataset:
  path: ./data/{}.json
  name: movieKnowledgeGraphDatasetWithSyntheticData
model:
  name: Qwen/Qwen3-4B
  quantization: 4
  gradient_checkpointing: true
  use_fast_tokenizer: false
  lora:
    peft_lora_r: 16
    peft_lora_alpha: 64
    target_modules:
    - q_proj
    - v_proj
train:
  num_rounds: ${flower.num_rounds}
  save_every_round: 5
  learning_rate_max: 5.0e-05
  learning_rate_min: 1.0e-06
  seq_length: 2048
  padding_side: left
  evaluate_split: true
  training_arguments:
    output_dir: null
    learning_rate: null
    per_device_train_batch_size: 16
    gradient_accumulation_steps: 1
    logging_steps: 10
    num_train_epochs: 3
    max_steps: 10
    report_to: null
    save_steps: 1000
    save_total_limit: 10
    gradient_checkpointing: ${model.gradient_checkpointing}
    lr_scheduler_type: constant
client_resources:
  num_cpus: 8
  num_gpus: 1.0
dp:
  noise_mult: 0.02
  clip_norm: 0.5
flower:
  num_clients: null
  num_rounds: 1
  fraction_fit: 0.02
  client_resou

In [4]:
# Load the federated dataset from the JSON file whose path is built from the
# config (`dataset.path` is a template that is filled with `dataset.name`).
# NOTE(review): `json` is not imported explicitly in the import cell; it is
# presumably made available by `from utils.utils import *` — confirm, or add
# an explicit `import json` to the import cell.
#
# The encoding is pinned to UTF-8 (the interchange encoding for JSON) so the
# file decodes the same way regardless of the platform's locale default.
with open(cfg.dataset.path.format(cfg.dataset.name), "r", encoding="utf-8") as file:
    datasets = json.load(file)

# The number of federated clients is the number of top-level entries in the
# loaded mapping; store it in the config so later cells can read it.
cfg.flower.num_clients = len(datasets)
print(cfg.flower.num_clients)

519


In [5]:
# ===== Define the tokenizer =====
# The tokenizer is loaded for the configured base model; fast/slow variant and
# padding side both come from the config.
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model.name,
    use_fast=cfg.model.use_fast_tokenizer,
    padding_side=cfg.train.padding_side,
)
if tokenizer.pad_token is None:
    # Prefer BOS when padding on the left and EOS when padding on the right.
    # Some tokenizers do not define one of these tokens; in that case fall
    # back to the other one instead of assigning None, which would leave the
    # tokenizer without a usable pad token.
    if cfg.train.padding_side == "left":
        pad_candidates = (tokenizer.bos_token, tokenizer.eos_token)
    else:
        pad_candidates = (tokenizer.eos_token, tokenizer.bos_token)
    tokenizer.pad_token = next(
        (token for token in pad_candidates if token is not None), None
    )
print(f"pad_token_id: {tokenizer.pad_token_id}")

pad_token_id: 151643


In [None]:
# Output location derived from the model and dataset names; it is handed to
# both the client factory here and the server-side evaluate function later.
save_path = f"./models/{cfg.model.name}/{cfg.dataset.name}"

# Build the client factory first, then wrap it in the Flower ClientApp.
client_fn = gen_client_fn(datasets, tokenizer, cfg.model, cfg.train, save_path)
client = fl.client.ClientApp(
    client_fn=client_fn,
    # mods=[fixedclipping_mod] # For Differential Privacy
)

In [None]:
def server_fn(context: Context):
    """Assemble the server-side components used by the Flower ServerApp.

    Args:
        context: Run context supplied by Flower when the ServerApp starts.
            It is not read by this function.

    Returns:
        An ``fl.server.ServerAppComponents`` holding a FedAvg strategy and a
        ``ServerConfig`` that runs ``cfg.flower.num_rounds`` rounds.
    """
    # Client pool size and the share of it sampled for training each round.
    total_clients = cfg.flower.num_clients
    fit_fraction = cfg.flower.sample_clients / total_clients

    # Optional hook that configures each "fit()" round.
    fit_config_fn = get_on_fit_config()

    # Optional hook executed on the server after each round.
    # In this example the function only saves the global model.
    server_evaluate_fn = get_evaluate_fn(
        cfg.model,
        cfg.train.save_every_round,
        cfg.flower.num_rounds,
        save_path,
    )

    # Define the Strategy
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=total_clients,  # total clients
        fraction_fit=fit_fraction,  # ratio of clients to sample
        fraction_evaluate=0.0,  # No federated evaluation
        on_fit_config_fn=fit_config_fn,
        # Optional function that aggregates the metrics sent by clients
        fit_metrics_aggregation_fn=fit_weighted_average,
        evaluate_fn=server_evaluate_fn,
    )

    # # Add Differential Privacy
    # NOTE(review): the printed config shows `dp` as a top-level section, so
    # the lookups below may need to be `cfg.dp.*`; `sampled_clients` is also a
    # float here while a client count is presumably expected — confirm both
    # before enabling.
    # sampled_clients = cfg.flower.num_clients*strategy.fraction_fit
    # strategy = DifferentialPrivacyClientSideFixedClipping(
    #     strategy,
    #     noise_multiplier=cfg.flower.dp.noise_mult,
    #     clipping_norm=cfg.flower.dp.clip_norm,
    #     num_sampled_clients=sampled_clients
    # )

    # Number of rounds to run the simulation
    server_config = fl.server.ServerConfig(num_rounds=cfg.flower.num_rounds)

    return fl.server.ServerAppComponents(strategy=strategy, config=server_config)


# Register the factory above as the server application.
server = fl.server.ServerApp(server_fn=server_fn)

In [None]:
# Copy the per-client resource settings from the config into a plain dict.
client_resources = dict(cfg.flower.client_resources)

# Backend settings: resources granted to each simulated client plus the
# initialisation arguments taken from `backend_setup`.
backend_config = {
    "client_resources": client_resources,
    "init_args": backend_setup,
}

# Launch the simulation with one supernode per federated client.
fl.simulation.run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=cfg.flower.num_clients,
    backend_config=backend_config,
)

[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=1, no round_timeout
[92mINFO [0m: 
[92mINFO [0m: [INIT]
[92mINFO [0m: Requesting initial parameters from one random client
