In [1]:
class Args:
    output_dir = "./saved_models"
    model_type = "roberta"
    config_name = "../../../model/codebert-base/config.json"
    model_name_or_path = "../../../model/codebert-base"
    tokenizer_name = "../../../model/codebert-base"
    train_data_dir = "../../tags/splits_even/"
    eval_data_file = "../../dataset/valid.txt"
    test_data_file = "../../dataset/test.txt"
    epoch = 2
    block_size = 400
    train_batch_size = 16
    eval_batch_size = 32
    learning_rate = 5e-5
    max_grad_norm = 1.0
    seed = 123456
    # 以下是没有明确提供的参数，我将使用默认值
    mlm = False
    mlm_probability = 0.15
    cache_dir = ""
    gradient_accumulation_steps = 1
    adam_epsilon = 1e-8
    max_steps = -1
    warmup_steps = 0
    logging_steps = 50
    eval_all_checkpoints = False
    no_cuda = False
    overwrite_output_dir = False
    overwrite_cache = False
    fp16 = False
    fp16_opt_level = 'O1'
    local_rank = -1
    server_ip = ''
    server_port = ''
    do_lower_case = ''

args = Args()


In [2]:
from collections import OrderedDict

from centralized import load_data, load_model, train, evaluate, DEVICE, parse_args
from torch.utils.data import DataLoader

import flwr as fl
import torch
import ray
import logging

In [3]:
logging.basicConfig(filename='training.log', level=logging.INFO)

def set_parameters(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
    return model

logging.info("Loading model")

try:
    model, args = load_model(args)
    logging.info("Loading model succesfully")
except Exception as e:
    logging.error(e)

args.n_gpu = 1
args.per_gpu_train_batch_size=args.train_batch_size//args.n_gpu
args.per_gpu_eval_batch_size=args.eval_batch_size//args.n_gpu
args.device = DEVICE

In [4]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainset, testset, cid):
        self.model = model
        self.trainset = trainset
        self.testset = testset
        self.cid = cid

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in model.state_dict().items()]

    def fit(self, parameters, config):
        set_parameters(self.model, parameters)
        train(args, self.model, self.trainset, self.cid)
        return self.get_parameters({}), len(self.trainset), {}

    def evaluate(self, parameters, config):
        set_parameters(self.model, parameters)
        eval_loss, result = evaluate(args, self.model, self.testset)
        return eval_loss, len(self.testset), result
    

def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""
    logging.info(f"Client {cid} is starting training...")
    
    trainset, testset = load_data(cid=cid, args=args)

    print(len(testset))

    # Create a client-specific Flower client
    return FlowerClient(model, trainset, testset, cid)

In [5]:
client_fn(1)

split_1: 90103
split_1: 9010


100%|██████████| 9010/9010 [00:07<00:00, 1238.56it/s]


valid: 415416
valid: 41541


100%|██████████| 41541/41541 [00:11<00:00, 3524.82it/s]


41541


<__main__.FlowerClient at 0x7efb25406620>

In [None]:
class CustomFedAvg(fl.server.strategy.FedAvg):
    def aggregate_fit(self, rnd, results, failures):
        logging.info(f"Server: Aggregating parameters in round {rnd}")
        aggregated_parameters = super().aggregate_fit(rnd, results, failures)
        if aggregated_parameters is not None:
            logging.info("Server: Parameters aggregated successfully")
        else:
            logging.info("Server: Parameter aggregation failed")
        return aggregated_parameters

In [None]:
NUM_CLIENTS = 10
client_resources = {"num_cpus": 8, "num_gpus": 1.0}

In [None]:
# fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=NUM_CLIENTS,
#     config=fl.server.ServerConfig(num_rounds=2),
#     client_resources=client_resources,
#     strategy=CustomFedAvg()
# )