# 定义算法

In [None]:
import flgo.algorithm.fedbase as fedbase
import flgo.utils.fmodule as fmodule
import copy
import numpy as np
import torch

class Server(fedbase.BasicServer):
    """Federated server that maintains a score-ranked committee of clients.

    Every round the server selects the current committee plus randomly drawn
    non-committee ("normal") clients, scores each participant from its
    accuracy, the norm of its model update and the time since it last
    participated, and aggregates the returned models weighted by those scores.
    """

    def initialize(self, *args, **kwargs):
        """Register the algorithm hyper-parameters and reset per-client state."""
        algo_params = {
            'd': self.num_clients,  # number of clients targeted per round
            'alpha': 1,  # scales log(N + 1) when sizing the committee
            'K_min': 1,  # lower bound on the committee size
            'delta_grad': 0.1,  # penalty factor on the (log-smoothed) update norm for committee scores
            'gamma': 0.1,  # time-decay rate applied to scores
            'w_data': 0.5,  # data-size weight (not referenced elsewhere in this class)
            'selected_round': 1,  # the committee is re-elected every `selected_round` rounds
        }
        self.init_algo_para(algo_params)
        self.gv.logger.info(f"Initialization parameters: {algo_params}")
        self.gv.logger.write_var_into_output("initialization_parameters", algo_params)

        # Committee membership plus per-client participation and score history.
        self.committee = []
        self.last_participation = {cid: 0 for cid in range(self.num_clients)}
        self.scores = {cid: 0.0 for cid in range(self.num_clients)}

    def sample(self):
        """Pick this round's participants: the committee plus random normal clients.

        Returns:
            Sorted list of selected client ids.
        """
        # N: number of participants this round; K: committee size.
        N = min(self.d, len(self.available_clients))
        K = max(min(N // 3 + 1, self.alpha * np.log(N + 1)), self.K_min)
        # Never ask for more committee seats than there are participants;
        # otherwise N - K goes negative and the random draw below would fail.
        K = min(int(K), N)

        # Re-elect the committee on the first call and every `selected_round`
        # rounds, taking the K highest-scoring available clients.
        if not self.committee or self.current_round % self.selected_round == 0:
            sorted_clients = sorted(self.available_clients, key=lambda x: self.scores[x], reverse=True)
            self.committee = sorted_clients[:K]
            self.gv.logger.info(f"Round {self.current_round}: Committee updated - {self.committee}")
            self.gv.logger.write_var_into_output(f"committee_round_{self.current_round}", list(self.committee))

        # Fill the remaining seats with random non-committee clients. The seat
        # count is based on the actual committee size so that a committee kept
        # from an earlier round still yields N participants in total.
        committee_members = set(self.committee)
        normal_clients = [cid for cid in self.available_clients if cid not in committee_members]
        num_normal = min(max(N - len(self.committee), 0), len(normal_clients))
        if num_normal > 0:
            drawn = np.random.choice(normal_clients, num_normal, replace=False)
            selected_normal = sorted(int(cid) for cid in drawn)  # plain ints, not numpy scalars
        else:
            selected_normal = []

        selected_clients = sorted(list(self.committee) + selected_normal)

        self.gv.logger.info(f"Round {self.current_round}: Normal clients selected - {selected_normal}")
        self.gv.logger.write_var_into_output(f"normal_clients_round_{self.current_round}", selected_normal)
        self.gv.logger.info(f"Round {self.current_round}: All selected clients - {selected_clients}")
        self.gv.logger.write_var_into_output(f"selected_clients_round_{self.current_round}", selected_clients)
        return selected_clients

    def iterate(self):
        """Run one round: sample, collect models, score participants, aggregate.

        Returns:
            True, after the global model has been replaced by the aggregate.
        """
        self.gv.logger.info(f"Round {self.current_round} started")
        self.gv.logger.write_var_into_output("round_start", self.current_round)

        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        models = res['model']
        # NOTE(review): assumes `models` is aligned one-to-one with
        # `self.selected_clients` (no dropped clients) - confirm against communicate().
        participants = list(zip(self.selected_clients, models))

        # The global model does not change while scoring, so evaluate it once.
        global_acc = self.test(self.model)['accuracy']
        # Each returned model is evaluated once, by the client that produced it.
        local_accs = [self.clients[cid].test(model)['accuracy'] for cid, model in participants]
        # Lowest local accuracy among this round's participants; it is the
        # baseline that non-committee scores are measured against.
        acc_min = min(local_accs) if local_accs else 0.0

        for (cid, model), local_acc in zip(participants, local_accs):
            # Norm of the model update, log-smoothed to damp very large updates.
            grad = fmodule._model_sub(self.model, model)
            grad_norm = fmodule._model_norm(grad)
            smoothed_grad_norm = torch.log(1 + grad_norm)
            grad_norm_value = grad_norm.item()
            smoothed_value = smoothed_grad_norm.item()

            # The longer since the client last participated, the smaller the factor.
            time_decay = np.exp(-self.gamma * (self.current_round - self.last_participation[cid]))

            if cid in self.committee:
                base_score = global_acc * (1 - self.delta_grad * smoothed_value)
            else:
                base_score = max(0, local_acc - acc_min)

            self.scores[cid] = float(base_score * time_decay)
            self.last_participation[cid] = self.current_round

            info_str = f"Round {self.current_round}, Client {cid}: Local Acc={local_acc:.4f}, Global Acc={global_acc:.4f}, Grad Norm={grad_norm_value:.4f}, Smoothed Grad Norm={smoothed_value:.4f}, Time Decay={time_decay:.4f}, Score={self.scores[cid]:.4f}"
            self.gv.logger.info(info_str)
            self.gv.logger.write_var_into_output(f"client_{cid}_metrics_round_{self.current_round}", {
                "local_acc": float(local_acc),
                "global_acc": float(global_acc),
                "grad_norm": float(grad_norm_value),
                "smoothed_grad_norm": float(smoothed_value),
                "time_decay": float(time_decay),
                "score": float(self.scores[cid])
            })

        self.model = self.aggregate(models)
        return True

    def aggregate(self, models):
        """Average `models` weighted by the normalised scores of their clients.

        Args:
            models: Models returned by `self.selected_clients`, in the same order.

        Returns:
            The score-weighted average model, or the current global model when
            `models` is empty.
        """
        if not models:
            # Nothing was received this round; keep the current global model.
            return self.model

        client_scores = [self.scores[cid] for cid in self.selected_clients]
        total_score = sum(client_scores)
        if total_score > 0:
            weights = [score / total_score for score in client_scores]
        else:
            # All scores are zero (or cancel out): normalising would divide by
            # zero, so fall back to a plain uniform average.
            weights = [1.0 / len(client_scores)] * len(client_scores)

        for cid, weight in zip(self.selected_clients, weights):
            self.gv.logger.info(f"Round {self.current_round}, Client {cid}: Aggregation Weight={weight:.4f}")
            self.gv.logger.write_var_into_output(f"client_{cid}_aggregation_weight_round_{self.current_round}", float(weight))

        self.gv.logger.info(f"Round {self.current_round}: Total Aggregation Weight={sum(weights):.4f}")
        self.gv.logger.write_var_into_output(f"total_aggregation_weight_round_{self.current_round}", float(sum(weights)))

        return fmodule._model_average(models, weights)

class Client(fedbase.BasicClient):
    """Client side of the algorithm; local training is inherited unchanged from ``BasicClient``."""
    def train(self, model):
        # Delegates straight to the base-class training routine.
        return super().train(model)

class MyAlgorithm:
    """Bundles the custom ``Server`` and ``Client`` classes; this object is what gets passed to ``flgo.init``."""
    Server = Server
    Client = Client

# 创建任务

In [None]:
# 单任务

import flgo
import os

# Directory that holds (or will hold) the single federated task.
task = './tasks/8.18_demo'

# MNIST classification, split IID across 30 clients.
gen_config = {
    'benchmark': {'name': 'flgo.benchmark.mnist_classification'},
    'partitioner': {
        'name': 'IIDPartitioner',
        'para': {'num_clients': 30},
    },
}

# Only generate the task when its directory is not on disk yet.
if not os.path.exists(task):
    flgo.gen_task(gen_config, task_path=task)


In [None]:
# 多任务
import flgo
import os

# One MNIST task per data-heterogeneity setting, all stored under
# tasks/multi-data/. Each entry is (task path, partitioner name, extra
# partitioner parameters); every task uses 30 clients.
configurations = {
    path: {
        'benchmark': {'name': 'flgo.benchmark.mnist_classification'},
        'partitioner': {'name': partitioner, 'para': {'num_clients': 30, **extra}},
    }
    for path, partitioner, extra in (
        ('./tasks/multi-data/iid', 'IIDPartitioner', {}),
        ('./tasks/multi-data/div01', 'DiversityPartitioner', {'diversity': 0.1}),
        ('./tasks/multi-data/div05', 'DiversityPartitioner', {'diversity': 0.5}),
        ('./tasks/multi-data/div09', 'DiversityPartitioner', {'diversity': 0.9}),
        ('./tasks/multi-data/dir01', 'DirichletPartitioner', {'alpha': 0.1}),
        ('./tasks/multi-data/dir10', 'DirichletPartitioner', {'alpha': 1.0}),
        ('./tasks/multi-data/dir50', 'DirichletPartitioner', {'alpha': 5.0}),
    )
}

# Generate every task that does not exist on disk yet.
# NOTE: the loop variable rebinds the module-level name `task`; after this
# cell runs, `task` refers to the last path in `configurations`.
for task, config in configurations.items():
    if os.path.exists(task):
        continue
    flgo.gen_task(config, task_path=task)

print("Tasks generated successfully under tasks/multi-data/")

# 跑程序

In [None]:
import flgo.algorithm.fedavg as fedavg

# Run options shared by both algorithms below.
# NOTE(review): 'gpu':[1,] presumably selects GPU device index 1, which would
# not exist on a single-GPU machine - confirm against the target hardware.
option = {'gpu':[1,],'log_file':True, 'num_rounds':2, 'proportion':1.0, 'learning_rate':0.01, 'num_epochs':10, 'sample':'uniform'}
# Baseline: FedAvg on `task` (whichever task path was bound last by the cells above).
fedavg_runner = flgo.init(task, fedavg, option=option, )
fedavg_runner.run()
# The committee-based algorithm defined in this notebook, on the same task and options.
# NOTE(review): despite its name, `scaffold_runner` runs MyAlgorithm, not SCAFFOLD.
scaffold_runner = flgo.init(task, MyAlgorithm, option=option, )
scaffold_runner.run()

# 分析结果，画图

In [None]:
# import flgo.experiment.analyzer as fea
# records = fea.load_records(task, ['MyAlgorithm'], option)
# # 迭代并打印每个记录的内容
# for record in records:
#     print("Algorithm:", record.algorithm)
#     print("Task:", record.task)
#     print("Options:", record.option)
#     print("Logs:")
#     for key, value in record.log.items():
#         print(f"  {key}: {value}")
#     print("-" * 40)

In [None]:
import flgo.experiment.analyzer as fea
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

class AnalyzerAndPainter:
    """Extract per-round, per-client metrics from run records and plot them.

    The log layout assumed here is the one used throughout this class: every
    log key maps to a list of the values written under that key, so the value
    of a key that is written once is read with ``[0]``.
    """

    # Fields of the per-client metrics dict that are turned into series.
    _METRIC_FIELDS = ('local_acc', 'global_acc', 'grad_norm', 'time_decay', 'score')

    def analyze_and_plot(self, task, option, num_clients=30, records=None):
        """Collect metric series from the run records (no plotting happens here).

        Args:
            task: Task path handed to ``fea.load_records``.
            option: Run options handed to ``fea.load_records``.
            num_clients: Number of client ids (``0 .. num_clients - 1``) to
                look for in the logs.
            records: Optional pre-loaded records; each needs a ``log`` mapping.
                When omitted, the 'MyAlgorithm' records are loaded from `task`.

        Returns:
            Tuple ``(rounds, local_acc, global_acc, grad_norm, time_decay,
            score, agg_weight, committee_counts, normal_client_counts)``.
            The per-client dicts are keyed by ``'client_{i}_metrics_round_'``
            and hold one entry per round; the entry is ``None`` for rounds in
            which the client logged no metrics. ``agg_weight`` holds ``0``
            when metrics were logged but no aggregation weight was.
        """
        if records is None:
            records = fea.load_records(task, ['MyAlgorithm'], option)

        rounds = []
        metrics = {field: {} for field in self._METRIC_FIELDS}
        agg_weight = {}
        committee_counts = []
        normal_client_counts = []

        client_ids = range(num_clients)

        for record in records:
            log = record.log
            record_rounds = log['round_start']
            rounds.extend(record_rounds)

            # Make sure every client has a series in every output dict.
            for cid in client_ids:
                series_key = f'client_{cid}_metrics_round_'
                for per_client in metrics.values():
                    per_client.setdefault(series_key, [])
                agg_weight.setdefault(series_key, [])

            # A committee entry may be missing for some rounds; reuse the most
            # recent size seen in this record (0 before the first entry).
            committee_size = 0
            for r in record_rounds:
                committee_key = f'committee_round_{r}'
                if committee_key in log:
                    # [0] is the logged member list; len() of the outer list
                    # would count writes, not members.
                    committee_size = len(log[committee_key][0])
                committee_counts.append(committee_size)

                normal_key = f'normal_clients_round_{r}'
                normal_client_counts.append(len(log[normal_key][0]) if normal_key in log else 0)

                for cid in client_ids:
                    series_key = f'client_{cid}_metrics_round_'
                    metrics_key = f'{series_key}{r}'

                    if metrics_key not in log:
                        # No metrics logged for this client in this round.
                        for per_client in metrics.values():
                            per_client[series_key].append(None)
                        agg_weight[series_key].append(None)
                        continue

                    entry = log[metrics_key][0]
                    for field, per_client in metrics.items():
                        per_client[series_key].append(entry[field])

                    # The weight key is built from the client id, not from the
                    # metrics prefix.
                    weight_key = f'client_{cid}_aggregation_weight_round_{r}'
                    agg_weight[series_key].append(log[weight_key][0] if weight_key in log else 0)

        return (rounds, metrics['local_acc'], metrics['global_acc'], metrics['grad_norm'],
                metrics['time_decay'], metrics['score'], agg_weight,
                committee_counts, normal_client_counts)

    def _plot_client_pairs(self, rounds, solid, dashed, clients, solid_name, dashed_name, ylabel, title):
        """Plot two series per client (solid and dashed) for the first five clients."""
        plt.figure(figsize=(10, 6))

        for client in clients[:5]:  # only the first five clients, to keep the legend readable
            plt.plot(rounds, solid[client], label=f'{client} {solid_name}')
            plt.plot(rounds, dashed[client], '--', label=f'{client} {dashed_name}')

        plt.xlabel('Round')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_accuracy(self, rounds, local_acc, global_acc, clients):
        """Plot local vs. global accuracy over rounds for the first five `clients`."""
        self._plot_client_pairs(rounds, local_acc, global_acc, clients,
                                'Local Acc', 'Global Acc', 'Accuracy', '客户端模型准确率和全局模型准确率')

    def plot_grad_time_decay(self, rounds, grad_norm, time_decay, clients):
        """Plot update norm and time-decay factor over rounds for the first five `clients`."""
        self._plot_client_pairs(rounds, grad_norm, time_decay, clients,
                                'Grad Norm', 'Time Decay', 'Value', '梯度范数和时间衰减因子')

    def plot_score_agg_weight(self, rounds, score, agg_weight, clients):
        """Plot score and aggregation weight over rounds for the first five `clients`."""
        self._plot_client_pairs(rounds, score, agg_weight, clients,
                                'Score', 'Aggregation Weight', 'Value', '客户端得分和聚合权重')

    def plot_committee_normal_clients(self, rounds, committee_counts, normal_client_counts):
        """Plot the number of committee members and normal clients per round."""
        plt.figure(figsize=(10, 6))
        plt.plot(rounds, committee_counts, label='Committee Members')
        plt.plot(rounds, normal_client_counts, '--', label='Normal Clients')

        plt.xlabel('Round')
        plt.ylabel('Number of Clients')
        plt.title('参与客户端数量和委员会成员变化')
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_heatmap(self, rounds, score):
        """Draw a clients-by-rounds heatmap of the scores in `score`."""
        # One column per client, one row per round. Casting to float turns the
        # None placeholders into NaN, so a client with no scores at all does
        # not leave an object-dtype column that the heatmap cannot draw.
        score_df = pd.DataFrame(score, index=rounds).astype(float)

        # Transpose so clients are rows and rounds are columns.
        score_df = score_df.T

        plt.figure(figsize=(15, 10))
        sns.heatmap(score_df, cmap='YlOrRd', cbar_kws={'label': 'Score'}, annot=False, fmt=".2f", linewidths=.5)

        plt.xlabel('Round')
        plt.ylabel('Client ID')
        plt.title('Client Scores Across Rounds')
        plt.show()


In [None]:
# Create the analyzer/plotter.
ap = AnalyzerAndPainter()

# Load the records and extract the metric series.
rounds, local_acc, global_acc, grad_norm, time_decay, score, agg_weight, committee_counts, normal_client_counts = ap.analyze_and_plot(task, option)

# Draw the figures (only the score heatmap is currently enabled).
# ap.plot_accuracy(rounds, local_acc, global_acc, list(local_acc.keys()))
# ap.plot_grad_time_decay(rounds, grad_norm, time_decay, list(grad_norm.keys()))
# ap.plot_score_agg_weight(rounds, score, agg_weight, list(score.keys()))
# ap.plot_committee_normal_clients(rounds, committee_counts, normal_client_counts)
ap.plot_heatmap(rounds, score)