In [1]:
import os
# Switch into the current directory (os.curdir is '.'), i.e. the directory the
# kernel was started in; as written this does not change the working directory.
# NOTE(review): presumably meant to point at the project root so the
# `data`/`model`/`training` package imports resolve — confirm.
os.chdir(os.curdir)

In [2]:
# federated_learning.py
import torch
from data.data_generation import generate_data
from model.model import LinearRegressionModel
from training.train import train_local_model, federated_averaging

def federated_learning(num_clients=5, num_rounds=10, samples_per_client=100):
    """
    Simulate federated learning on synthetic per-client data.

    Each round, every client copies the current global parameters into a
    fresh ``LinearRegressionModel`` and trains it on its own local data via
    ``train_local_model``. ``federated_averaging`` then merges the returned
    client state dicts into the new global parameters.

    Args:
        num_clients: number of simulated clients; must be at least 1.
        num_rounds: number of communication rounds; must be non-negative.
            With 0 rounds the freshly initialised global model is returned.
        samples_per_client: number of samples passed to ``generate_data`` for
            each client (defaults to 100, the previously hard-coded value).

    Returns:
        The global ``LinearRegressionModel`` after the final round.

    Raises:
        ValueError: if ``num_clients`` < 1, ``num_rounds`` < 0, or
            ``samples_per_client`` < 1.
    """
    # Validate before building anything: aggregation over zero clients has
    # nothing to average, and a negative round count is almost surely a bug.
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if num_rounds < 0:
        raise ValueError(f"num_rounds must be >= 0, got {num_rounds}")
    if samples_per_client < 1:
        raise ValueError(
            f"samples_per_client must be >= 1, got {samples_per_client}"
        )

    # Initialise the global model.
    global_model = LinearRegressionModel()

    # Each client's dataset is generated once and reused in every round.
    client_data = [
        generate_data(samples_per_client) for _ in range(num_clients)
    ]

    for round_num in range(1, num_rounds + 1):
        print(f"Round {round_num}/{num_rounds}")

        # The global parameters do not change while clients train, so read
        # them once per round. load_state_dict copies the values into each
        # local model, so local training cannot mutate the global model.
        global_state = global_model.state_dict()

        client_state_dicts = []
        for data, targets in client_data:
            # Every client starts this round from the current global
            # parameters, not from a randomly initialised model.
            local_model = LinearRegressionModel()
            local_model.load_state_dict(global_state)
            client_state_dicts.append(
                train_local_model(local_model, data, targets)
            )

        # Aggregate the client updates into the new global parameters.
        global_model.load_state_dict(federated_averaging(client_state_dicts))

        # Report the current global model parameters.
        print(f"Global model parameters: {global_model.state_dict()}")

    return global_model

# Run the simulation with 5 clients for 10 rounds, then report the result.
# `global_model` stays a module-level name so later cells can reuse it.
global_model = federated_learning(num_clients=5, num_rounds=10)
print("Final global model parameters:", global_model.state_dict(), sep="\n")


Round 1/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.6332]])), ('fc.bias', tensor([1.2021]))])
Round 2/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.7174]])), ('fc.bias', tensor([1.1572]))])
Round 3/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.7819]])), ('fc.bias', tensor([1.1226]))])
Round 4/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.8314]])), ('fc.bias', tensor([1.0961]))])
Round 5/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.8693]])), ('fc.bias', tensor([1.0757]))])
Round 6/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.8984]])), ('fc.bias', tensor([1.0601]))])
Round 7/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.9207]])), ('fc.bias', tensor([1.0482]))])
Round 8/10
Global model parameters: OrderedDict([('fc.weight', tensor([[1.9378]])), ('fc.bias', tensor([1.0390]))])
Round 9/10
Global model parameters: OrderedDict([('fc.weight', tensor([[