In [1]:
import flgo.algorithm.fedavg as fedavg
import flgo.experiment.analyzer
import flgo.experiment.logger as fel
import flgo.simulator.base
import flgo.benchmark.cifar10_classification as cifar
import flgo.benchmark.partition as fbp
import os
import flgo.simulator.base
import random
# 1. 定义两种活跃度分布用于对比
# 1.1 第一种分布用户状态转移的概率均为0.1，即用户的状态难以改变，将体现在用户连续活跃时长和连续非活跃时长较长，活跃度变化频率低
class MySimulator(flgo.simulator.base.BasicSimulator):
    """Simulator whose clients rarely change availability state.

    After the initial round both transition probabilities are fixed at
    ``PROB_AVAILABLE`` / ``PROB_UNAVAILABLE`` (0.1 each), so client states are
    hard to change: long active and long inactive streaks, low switching
    frequency. Subclass and override the two class attributes to try other
    rates.
    """

    # Per-client transition probabilities applied on every call after time 0.
    PROB_AVAILABLE = 0.1
    PROB_UNAVAILABLE = 0.1

    def update_client_availability(self):
        """Write 'prob_available' / 'prob_unavailable' for every client."""
        # Size the value lists from the same id list they are assigned to, so
        # ids and values always line up one-to-one.
        num_clients = len(self.all_clients)
        if self.gv.clock.current_time == 0:
            # Initial setup: 'prob_available' is 1 for everyone while
            # 'prob_unavailable' is a fair 0/1 coin flip per client
            # (presumably leaving about half the clients unavailable at the
            # start -- confirm against BasicSimulator's semantics).
            self.set_variable(self.all_clients, 'prob_available', [1] * num_clients)
            self.set_variable(
                self.all_clients,
                'prob_unavailable',
                [int(random.random() >= 0.5) for _ in range(num_clients)],
            )
            return
        self.set_variable(self.all_clients, 'prob_available', [self.PROB_AVAILABLE] * num_clients)
        self.set_variable(self.all_clients, 'prob_unavailable', [self.PROB_UNAVAILABLE] * num_clients)

# 1.2 第二种分布用户状态转移的概率均为0.9，即用户的状态极易改变，将体现在用户活跃度变化频率高
class MySimulator2(flgo.simulator.base.BasicSimulator):
    """Simulator whose clients change availability state very easily.

    After the initial round both transition probabilities are fixed at
    ``PROB_AVAILABLE`` / ``PROB_UNAVAILABLE`` (0.9 each), so client states flip
    frequently (high switching frequency). Subclass and override the two class
    attributes to try other rates.
    """

    # Per-client transition probabilities applied on every call after time 0.
    PROB_AVAILABLE = 0.9
    PROB_UNAVAILABLE = 0.9

    def update_client_availability(self):
        """Write 'prob_available' / 'prob_unavailable' for every client."""
        # Size the value lists from the same id list they are assigned to, so
        # ids and values always line up one-to-one.
        num_clients = len(self.all_clients)
        if self.gv.clock.current_time == 0:
            # Initial setup: 'prob_available' is 1 for everyone while
            # 'prob_unavailable' is a fair 0/1 coin flip per client
            # (presumably leaving about half the clients unavailable at the
            # start -- confirm against BasicSimulator's semantics).
            self.set_variable(self.all_clients, 'prob_available', [1] * num_clients)
            self.set_variable(
                self.all_clients,
                'prob_unavailable',
                [int(random.random() >= 0.5) for _ in range(num_clients)],
            )
            return
        self.set_variable(self.all_clients, 'prob_available', [self.PROB_AVAILABLE] * num_clients)
        self.set_variable(self.all_clients, 'prob_unavailable', [self.PROB_UNAVAILABLE] * num_clients)

# 2. 生成联邦任务以测试
# Generate the IID-partitioned CIFAR-10 federated task only when its directory
# does not exist yet; an existing directory is reused as-is.
task = './IID_cifar10'
gen_config = dict(
    benchmark=cifar,
    partitioner=fbp.IIDPartitioner,
)
if not os.path.exists(task):
    flgo.gen_task(gen_config, task_path=task)

# 3. 定制Logger以记录用户的活跃度分布
class MyLogger(fel.BasicLogger):
    """Logger that records which clients are available at each logging call.

    ``log_once`` does not call the base implementation; it appends the
    coordinator's available-client list to ``self.output['available_clients']``
    (skipping clock time 0) and echoes it to stdout.
    """

    def log_once(self, *args, **kwargs):
        """Append a snapshot of the coordinator's currently available clients."""
        if self.gv.clock.current_time == 0:
            return
        # Store a copy rather than the coordinator's own object: if that list
        # were ever mutated in place, an aliased reference would silently
        # rewrite every earlier entry of the log.
        snapshot = list(self.coordinator.available_clients)
        self.output['available_clients'].append(snapshot)
        print(snapshot)

# if __name__ == '__main__':
#     # 4. 在两种模拟环境下分别运行联邦任务
#     runner1 = flgo.init(task, fedavg, {'gpu':[0,],'log_file':True, 'num_steps':1, 'num_rounds':100}, Logger=MyLogger, Simulator=MySimulator)
#     runner1.run()
#     runner2 = flgo.init(task, fedavg, {'gpu':[0,],'log_file':True, 'num_steps':1, 'num_rounds':100}, Logger=MyLogger, Simulator=MySimulator2)
#     runner2.run()

#     # 5. 可视化用户活跃度分布
#     selector = flgo.experiment.analyzer.Selector({'task':task, 'header':['fedavg',], })
#     def visualize_availability(rec_data, title = ''):
#         avl_clients = rec_data['available_clients']
#         all_points_x = []
#         all_points_y = []
#         for round in range(len(avl_clients)):
#             all_points_x.extend([round + 1 for _ in avl_clients[round]])
#             all_points_y.extend([cid for cid in avl_clients[round]])
#         import matplotlib.pyplot as plt
#         plt.scatter(all_points_x, all_points_y, s=10)
#         plt.title(title)
#         plt.xlabel('communication round')
#         plt.ylabel('client ID')
#         plt.show()
#     rec0 = selector.records[task][0]
#     visualize_availability(rec0.data, rec0.name[rec0.name.find('_SIM')+4:rec0.name.find('_SIM')+16])
#     rec1 = selector.records[task][1]
#     visualize_availability(rec1.data, rec1.name[rec1.name.find('_SIM')+4:rec1.name.find('_SIM')+16])


In [None]:
import flgo.algorithm.fedavg as fedavg
import flgo.experiment.logger as fel
import flgo.simulator.base
import flgo.benchmark.cifar10_classification as cifar
import flgo.benchmark.partition as fbp
import os
import random

# 1.1 第一种分布用户状态转移的概率均为0.1
class MySimulator(flgo.simulator.base.BasicSimulator):
    """Simulator whose clients rarely change availability state.

    After the initial round both transition probabilities are fixed at
    ``PROB_AVAILABLE`` / ``PROB_UNAVAILABLE`` (0.1 each), so client states are
    hard to change: long active and long inactive streaks, low switching
    frequency. Subclass and override the two class attributes to try other
    rates.
    """

    # Per-client transition probabilities applied on every call after time 0.
    PROB_AVAILABLE = 0.1
    PROB_UNAVAILABLE = 0.1

    def update_client_availability(self):
        """Write 'prob_available' / 'prob_unavailable' for every client."""
        # Size the value lists from the same id list they are assigned to, so
        # ids and values always line up one-to-one.
        num_clients = len(self.all_clients)
        if self.gv.clock.current_time == 0:
            # Initial setup: 'prob_available' is 1 for everyone while
            # 'prob_unavailable' is a fair 0/1 coin flip per client
            # (presumably leaving about half the clients unavailable at the
            # start -- confirm against BasicSimulator's semantics).
            self.set_variable(self.all_clients, 'prob_available', [1] * num_clients)
            self.set_variable(
                self.all_clients,
                'prob_unavailable',
                [int(random.random() >= 0.5) for _ in range(num_clients)],
            )
            return
        self.set_variable(self.all_clients, 'prob_available', [self.PROB_AVAILABLE] * num_clients)
        self.set_variable(self.all_clients, 'prob_unavailable', [self.PROB_UNAVAILABLE] * num_clients)

# 2. 生成联邦任务以测试
# Generate the IID-partitioned CIFAR-10 federated task only when its directory
# does not exist yet; an existing directory is reused as-is.
task = './IID_cifar10'
gen_config = dict(
    benchmark=cifar,
    partitioner=fbp.IIDPartitioner,
)
if not os.path.exists(task):
    flgo.gen_task(gen_config, task_path=task)

# 3. 定制Logger以记录用户的活跃度分布
class MyLogger(fel.BasicLogger):
    """Logger that records which clients are available at each logging call.

    ``log_once`` does not call the base implementation; it appends the
    coordinator's available-client list to ``self.output['available_clients']``
    (skipping clock time 0) and echoes it to stdout.
    """

    def log_once(self, *args, **kwargs):
        """Append a snapshot of the coordinator's currently available clients."""
        if self.gv.clock.current_time == 0:
            return
        # Store a copy rather than the coordinator's own object: if that list
        # were ever mutated in place, an aliased reference would silently
        # rewrite every earlier entry of the log.
        snapshot = list(self.coordinator.available_clients)
        self.output['available_clients'].append(snapshot)
        print(snapshot)

def simulator_name_from_record(record_name):
    """Return the simulator name embedded in an flgo record name, or None.

    Assumes the record name contains a ``_SIM<name>`` segment (the original
    notebook sliced the characters right after ``'_SIM'`` to get it -- confirm
    against flgo's record naming). The name runs up to the next ``'_'`` or
    ``'.'``. Returns None when no ``_SIM`` segment is present.
    """
    import re

    match = re.search(r'_SIM([^_.]+)', record_name)
    return match.group(1) if match else None


if __name__ == '__main__':
    # 4. Run the federated task in the first simulated environment.
    runner = flgo.init(task, fedavg, {'gpu':[0,], 'log_file':True, 'num_steps':1, 'num_rounds':100}, Logger=MyLogger, Simulator=MySimulator)
    runner.run()

    # 5. Visualize the client availability distribution.
    import flgo.experiment.analyzer
    import matplotlib.pyplot as plt

    selector = flgo.experiment.analyzer.Selector({'task':task, 'header':['fedavg',], })

    def visualize_availability(rec_data, title=''):
        """Scatter-plot one (round, client id) point per available client."""
        avl_clients = rec_data['available_clients']
        all_points_x = []
        all_points_y = []
        # Rounds are numbered from 1 on the x axis.
        for round_idx, clients in enumerate(avl_clients, start=1):
            all_points_x.extend([round_idx] * len(clients))
            all_points_y.extend(clients)
        plt.scatter(all_points_x, all_points_y, s=10)
        plt.title(title)
        plt.xlabel('communication round')
        plt.ylabel('client ID')
        plt.show()

    # Every fedavg run on this task matches the selector, including runs made
    # with a different simulator, so choose the record produced by this
    # cell's simulator instead of blindly taking the first one.
    records = selector.records[task]
    target = MySimulator.__name__
    matching = [r for r in records if simulator_name_from_record(r.name) == target]
    if matching:
        rec = matching[0]
    else:
        print(f"No record for simulator {target!r} found; falling back to {records[0].name!r}")
        rec = records[0]
    visualize_availability(rec.data, simulator_name_from_record(rec.name) or rec.name)

In [None]:
import flgo.algorithm.fedavg as fedavg
import flgo.experiment.logger as fel
import flgo.simulator.base
import flgo.benchmark.cifar10_classification as cifar
import flgo.benchmark.partition as fbp
import os
import random

# 1.2 第二种分布用户状态转移的概率均为0.9
class MySimulator2(flgo.simulator.base.BasicSimulator):
    """Simulator whose clients change availability state very easily.

    After the initial round both transition probabilities are fixed at
    ``PROB_AVAILABLE`` / ``PROB_UNAVAILABLE`` (0.9 each), so client states flip
    frequently (high switching frequency). Subclass and override the two class
    attributes to try other rates.
    """

    # Per-client transition probabilities applied on every call after time 0.
    PROB_AVAILABLE = 0.9
    PROB_UNAVAILABLE = 0.9

    def update_client_availability(self):
        """Write 'prob_available' / 'prob_unavailable' for every client."""
        # Size the value lists from the same id list they are assigned to, so
        # ids and values always line up one-to-one.
        num_clients = len(self.all_clients)
        if self.gv.clock.current_time == 0:
            # Initial setup: 'prob_available' is 1 for everyone while
            # 'prob_unavailable' is a fair 0/1 coin flip per client
            # (presumably leaving about half the clients unavailable at the
            # start -- confirm against BasicSimulator's semantics).
            self.set_variable(self.all_clients, 'prob_available', [1] * num_clients)
            self.set_variable(
                self.all_clients,
                'prob_unavailable',
                [int(random.random() >= 0.5) for _ in range(num_clients)],
            )
            return
        self.set_variable(self.all_clients, 'prob_available', [self.PROB_AVAILABLE] * num_clients)
        self.set_variable(self.all_clients, 'prob_unavailable', [self.PROB_UNAVAILABLE] * num_clients)

# 2. 生成联邦任务以测试
# Generate the IID-partitioned CIFAR-10 federated task only when its directory
# does not exist yet; an existing directory is reused as-is.
task = './IID_cifar10'
gen_config = dict(
    benchmark=cifar,
    partitioner=fbp.IIDPartitioner,
)
if not os.path.exists(task):
    flgo.gen_task(gen_config, task_path=task)

# 3. 定制Logger以记录用户的活跃度分布
class MyLogger(fel.BasicLogger):
    """Logger that records which clients are available at each logging call.

    ``log_once`` does not call the base implementation; it appends the
    coordinator's available-client list to ``self.output['available_clients']``
    (skipping clock time 0) and echoes it to stdout.
    """

    def log_once(self, *args, **kwargs):
        """Append a snapshot of the coordinator's currently available clients."""
        if self.gv.clock.current_time == 0:
            return
        # Store a copy rather than the coordinator's own object: if that list
        # were ever mutated in place, an aliased reference would silently
        # rewrite every earlier entry of the log.
        snapshot = list(self.coordinator.available_clients)
        self.output['available_clients'].append(snapshot)
        print(snapshot)

def simulator_name_from_record(record_name):
    """Return the simulator name embedded in an flgo record name, or None.

    Assumes the record name contains a ``_SIM<name>`` segment (the original
    notebook sliced the characters right after ``'_SIM'`` to get it -- confirm
    against flgo's record naming). The name runs up to the next ``'_'`` or
    ``'.'``. Returns None when no ``_SIM`` segment is present.
    """
    import re

    match = re.search(r'_SIM([^_.]+)', record_name)
    return match.group(1) if match else None


if __name__ == '__main__':
    # 4. Run the federated task in the second simulated environment.
    runner = flgo.init(task, fedavg, {'gpu':[0,], 'log_file':True, 'num_steps':1, 'num_rounds':100}, Logger=MyLogger, Simulator=MySimulator2)
    runner.run()

    # 5. Visualize the client availability distribution.
    import flgo.experiment.analyzer
    import matplotlib.pyplot as plt

    selector = flgo.experiment.analyzer.Selector({'task':task, 'header':['fedavg',], })

    def visualize_availability(rec_data, title=''):
        """Scatter-plot one (round, client id) point per available client."""
        avl_clients = rec_data['available_clients']
        all_points_x = []
        all_points_y = []
        # Rounds are numbered from 1 on the x axis.
        for round_idx, clients in enumerate(avl_clients, start=1):
            all_points_x.extend([round_idx] * len(clients))
            all_points_y.extend(clients)
        plt.scatter(all_points_x, all_points_y, s=10)
        plt.title(title)
        plt.xlabel('communication round')
        plt.ylabel('client ID')
        plt.show()

    # Every fedavg run on this task matches the selector, including the run
    # made with the first simulator, so choose the record produced by this
    # cell's simulator instead of blindly taking the first one.
    records = selector.records[task]
    target = MySimulator2.__name__
    matching = [r for r in records if simulator_name_from_record(r.name) == target]
    if matching:
        rec = matching[0]
    else:
        print(f"No record for simulator {target!r} found; falling back to {records[0].name!r}")
        rec = records[0]
    visualize_availability(rec.data, simulator_name_from_record(rec.name) or rec.name)