In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
import random
from torchvision import datasets, transforms
from collections import deque
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, mnist_iid_drl_local_divide, cifar_iid_drl_local_divide, mnist_noniid_drl_local_divide
from models.Update import LocalUpdate
from models.Update_divide import LocalUpdate_divide
from models.Nets import MLP, CNNMnist, CNNCifar, CNNCifarEmb, CNNCifarEmbReverse, CNNMnistEmb, CNNMnistEmbReverse
from models.Fed import FedAvg, FedPareto
from models.args import args_parser
# from models.test import test_img

# parse args
args = args_parser()
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

# load dataset and split users
if args.dataset == 'mnist':
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(1,1,1))])
    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=False, transform=trans_mnist)
    dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist)
    # sample users
    args.iid = False
    if args.iid:
        dict_users = mnist_iid(dataset_train, args.num_users)
    else:
        dict_users = mnist_noniid(dataset_train, args.num_users)
elif args.dataset == 'cifar':
    trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans_cifar)
    dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True, transform=trans_cifar)
    args.iid = True
    if args.iid:
        dict_users = cifar_iid(dataset_train, args.num_users)
    else:
        exit('Error: only consider IID setting in CIFAR10')
else:
    exit('Error: unrecognized dataset')
img_size = dataset_train[0][0].shape

# build model
if args.model == 'cnn' and args.dataset == 'cifar':
    net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
    net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
    len_in = 1
    for x in img_size:
        len_in *= x
    net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
else:
    exit('Error: unrecognized model')
print(net_glob)
net_glob.train()

# copy weights
w_glob = net_glob.state_dict()

CNNMnist(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)




In [2]:
# sum(x.numel() for x in net_glob.parameters())

In [3]:
from models.DQN import DQN
from models.buffer import MemoryBuffer
from models.prioritized_buffer import PrioritizedBuffer
from models.test import test_img

In [4]:
#每个local的每一层在emb之后拼接起来，再乘以(100+1)，或者用均值。分别对应是400和100
parameter_dim = (args.num_users+1) * 100
action_dim = args.num_users
print(parameter_dim)

10100


In [5]:
replay_buffer = MemoryBuffer(500)
# replay_buffer = PrioritizedBuffer(100)
dqn = DQN(parameter_dim, action_dim, replay_buffer, args)

In [6]:
layer_dict = {}
layer_name = []
count = 0
for name in w_glob.keys():
    if count % 2 == 0:
        layer_name.append(name.split('.',1)[0])
    count += 1
    
for i in layer_name:
#     layer_dict[i] = CNNCifarEmb(torch.cat([w_glob[i+'.weight'].reshape(1,-1), w_glob[i+'.bias'].reshape(1,-1)], 1).numel())
    layer_dict[i] = CNNMnistEmb(torch.cat([w_glob[i+'.weight'].reshape(1,-1), w_glob[i+'.bias'].reshape(1,-1)], 1).numel())
    
# emb_reverse = CNNCifarEmbReverse(args)
emb_reverse = CNNMnistEmbReverse(args)

optimizer = torch.optim.Adam([
    {'params':layer_dict[layer_name[0]].parameters()},
    {'params':layer_dict[layer_name[1]].parameters()},
    {'params':layer_dict[layer_name[2]].parameters()},
    {'params':layer_dict[layer_name[3]].parameters()},
#     {'params':layer_dict[layer_name[4]].parameters()},
    {'params':emb_reverse.parameters()}
] ,0.01)

w_save = []
for iter in range(args.emb_train_epochs):
    idxs_users = np.random.choice(range(args.num_users), 100, replace=False)
    
    for idx in idxs_users:
#         local = LocalUpdate_divide(args=args, dataset=dataset_train, idxs=dict_users[0][idx])
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
#         w, loss, loss_list = local.train(net=copy.deepcopy(net_glob).to(args.device))
        if iter == 0:
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_save.append(w)
        if iter > 0:
#             net_self = CNNMnist(args=args).to(args.device)
#             net_self.load_state_dict(w_save[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
#             w_save[idx] = w
            
        for i in layer_name:
            if i == 'conv1':
                emb_feature = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1).to(args.device))
            else:
                emb_feature += layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1).to(args.device))
        avg_emb_feature = emb_feature/4
        transform_w = emb_reverse.forward(avg_emb_feature)
        loss_w = [sum((w[i].reshape(1,-1) - transform_w[i].reshape(1,-1)) ** 2) for i in w_glob.keys()]
        loss_avg = 0
        loss_check_dict = {}
        for i in range(len(loss_w)):
            loss_avg += sum(loss_w[i])/len(loss_w[i])
            loss_check_dict[len(loss_w[i])] = loss_w[i]
            if loss_avg.item() > 3.0:
                print('len:{},w:{}'.format(len(loss_w[i]), loss_check_dict[len(loss_w[i])]))
                print('loss_sum:', loss_avg)
                print('\n***************')
        optimizer.zero_grad()
        loss_avg.backward(retain_graph=True)
        optimizer.step()
        print('epoch:{}, loss_avg:{}'.format(iter, loss_avg))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch:0, loss_avg:0.15005645155906677
epoch:0, loss_avg:0.4022198021411896
epoch:0, loss_avg:0.3488524258136749
epoch:0, loss_avg:1.0450294017791748
epoch:0, loss_avg:0.2595255970954895
epoch:0, loss_avg:0.14902356266975403
epoch:0, loss_avg:0.11621211469173431
epoch:0, loss_avg:0.09890931099653244
epoch:0, loss_avg:0.09907639771699905
epoch:0, loss_avg:0.4713357090950012
epoch:0, loss_avg:0.34891873598098755
epoch:0, loss_avg:0.1029767170548439
epoch:0, loss_avg:0.09997280687093735
epoch:0, loss_avg:0.11612743884325027
epoch:0, loss_avg:0.06369136273860931
epoch:0, loss_avg:0.07751049846410751
epoch:0, loss_avg:0.08532143384218216
epoch:0, loss_avg:0.06297529488801956
epoch:0, loss_avg:0.1431894302368164
epoch:0, loss_avg:0.18858301639556885
epoch:0, loss_avg:0.04950165003538132
epoch:0, loss_avg:0.07047183811664581
epoch:0, loss_avg:0.04150144010782242
epoch:0, loss_avg:0.04543745145201683
epoch:0, loss_avg:0.07671115547418594
epoch:0, loss_avg:0.06391266733407974
epoch:0, loss_avg:0

epoch:0, loss_avg:4.644525051116943
epoch:0, loss_avg:0.04336868226528168
epoch:0, loss_avg:0.060976579785346985
epoch:0, loss_avg:0.06987003982067108
epoch:0, loss_avg:0.058429867029190063
epoch:0, loss_avg:0.04810609668493271
epoch:0, loss_avg:0.0905868262052536
epoch:0, loss_avg:0.06640742719173431
epoch:0, loss_avg:0.08427976816892624
epoch:0, loss_avg:0.07063858211040497
epoch:0, loss_avg:0.09314839541912079
epoch:0, loss_avg:0.08681177347898483
epoch:0, loss_avg:0.06721099466085434
epoch:0, loss_avg:0.051178231835365295
epoch:0, loss_avg:0.05336005985736847
epoch:0, loss_avg:0.0630878284573555
epoch:0, loss_avg:0.048411283642053604
epoch:0, loss_avg:0.055836305022239685
epoch:0, loss_avg:0.04145362228155136
epoch:0, loss_avg:0.18913817405700684
epoch:0, loss_avg:0.05807746946811676
epoch:0, loss_avg:0.0625775009393692
epoch:0, loss_avg:0.043869126588106155
epoch:0, loss_avg:0.0521882064640522
epoch:0, loss_avg:0.08007463812828064
epoch:0, loss_avg:0.0765778049826622
epoch:0, loss

epoch:2, loss_avg:0.053771037608385086
epoch:2, loss_avg:0.04329332709312439
epoch:2, loss_avg:0.03486855328083038
epoch:2, loss_avg:0.06222553923726082
epoch:2, loss_avg:0.05538228899240494
epoch:2, loss_avg:0.053843121975660324
epoch:2, loss_avg:0.07116333395242691
epoch:2, loss_avg:0.05350669473409653
epoch:2, loss_avg:0.04448404163122177
epoch:2, loss_avg:0.04406748339533806
epoch:2, loss_avg:0.07138630002737045
epoch:2, loss_avg:0.06176557019352913
epoch:2, loss_avg:0.047291770577430725
epoch:2, loss_avg:0.05378703027963638
epoch:2, loss_avg:0.040329791605472565
epoch:2, loss_avg:0.07213686406612396
epoch:2, loss_avg:0.04487380012869835
epoch:2, loss_avg:0.04402980953454971
epoch:2, loss_avg:0.047428593039512634
epoch:2, loss_avg:0.050896983593702316
epoch:2, loss_avg:0.04606828838586807


In [7]:
# training rl

#先全部训练一次，然后找topk个客户端，这topk个再训练一次，fedavg。之后再从100个里面找topk
#（其他90个就不更新了，也不接收global，除非被选中）


loss_train = []
cv_loss, cv_acc = [], []
val_loss_pre, counter = 0, 0
net_best = None
best_loss = None
val_acc_list, net_list = [], []

#用来装w和更新被选择后训练了的w
w_save = []
loss_save = []

#使用网络采样前可以先多进行一段时间的随机采样来丰富buffer
threshold = 10
m = max(int(args.frac * args.num_users), 1)
constant = 64
target_acc = 0.99
args.emb = False

beta_start = 0.4
beta_frames = 1000 
beta_by_frame = lambda frame_idx: min(1.0, beta_start + frame_idx * (1.0 - beta_start) / beta_frames)

# dqn = torch.load('doubledqn2_{}_{}_{}_clientnumber{}_localep{}.pt'.format(args.dataset, args.epochs, args.model, m, args.local_ep))
# dqn = torch.load('dqn_{}_{}_{}_clientnumber{}_localep{}.pt'.format(args.dataset, args.epochs, args.model, m, args.local_ep))

last_replay_data = []
for iter in range(args.epochs):
    #判断7次（2**6=64），视情况k值减半
#     if iter % int(args.epochs/7) == 0:
#         decay_flag = np.mean(loss_save)
#         if np.mean(loss_save) > decay_flag:
#             args.k = int(args.k * args.k_frac)
#             #buffer要清空，不然optimize时拼接会有问题
#             replay_buffer = PrioritizedBuffer(100)
#             print('dacay')
    
    if iter == 0:
        random_n = 0
        n_weight = []
        while random_n<args.num_users*args.frac:
            n_weight.append(random.random()) # 随机初始化参数
            random_n+=1
#         action = random.sample(range(0,100),1)[0]
        action = random.sample(range(0,100),10)

    #重置global参数    
    if iter % args.reset_flag == 0:
        if args.model == 'cnn' and args.dataset == 'cifar':
            net_glob = CNNCifar(args=args).to(args.device)
        elif args.model == 'cnn' and args.dataset == 'mnist':
            net_glob = CNNMnist(args=args).to(args.device)
        elif args.model == 'mlp':
            len_in = 1
            for x in img_size:
                len_in *= x
            net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
        else:
            exit('Error: unrecognized model')
        
        
    loss_locals = []
    w_locals = []
    p_emb_collect = []
    for i in layer_name:
        if i == 'conv1':
            emb_global = layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))
        else:
            emb_global += layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))
    emb_global = emb_global/4  #(1*100)
    
#拼接
#     for i in layer_name:
#         if i == 'conv1':
#             emb_global = layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))
#         else:
#             emb_global = torch.cat([emb_global, layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))], 1)
                
    p_emb_collect.append(emb_global)
        
#     #先全部训练一次
#     if iter == 0:
#         for idx in range(100):
#             net_init = CNNMnist(args=args).to(args.device)
#             local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], flag=True)
#             w, loss = local.train(net=copy.deepcopy(net_init).to(args.device))
            
#             for i in layer_name:
#                 if i == 'conv1':
#                     emb_feature = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
#                 else:
#                     emb_feature += layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
           
#             avg_emb_feature = emb_feature/4
            
    
#             ##########  拼接方式  ##########
# #             for i in layer_name:
# #                 if i == 'conv1':
# #                     cat_emb = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
# #                 else:
# #                     cat_emb = torch.cat([cat_emb, layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))], 1)                    

#             ############ 储存嵌入 ##############
#             p_emb_collect.append(avg_emb_feature)
# #             p_emb_collect.append(cat_emb)
#             ############ 储存参数 ##############
#             w_locals.append(copy.deepcopy(w))
#             w_save.append(w)
# #             net_tmp = CNNMnist(args=args).to(args.device)
# #             net_tmp.load_state_dict(w)
# #             global_acc, loss_train = test_img(net_tmp, dataset_train, args)
# #             print(global_acc)


#     elif iter >= 1:

#之后用global赋值然后训练
    for idx in range(100):
#         if idx == action:
#             local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], flag=True)
#             w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
#         else:
#             local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
#             net_self = CNNMnist(args=args).to(args.device)
#             net_self.load_state_dict(w_save[idx])
#             w, loss = local.train(net=copy.deepcopy(net_self).to(args.device))
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], flag=True)
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
#         if loss != 0:
#             print('local loss:', loss)
    #把parameter转成嵌入

        ##########  求均值方式  #######
        for i in layer_name:
            if i == 'conv1':
                emb_feature = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
            else:
                emb_feature += layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))

#             #分母对应local网络的层数
        avg_emb_feature = emb_feature/4

        ##########  拼接方式  ##########
#             for i in layer_name:
#                 if i == 'conv1':
#                     cat_emb = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
#                 else:
#                     cat_emb = torch.cat([cat_emb, layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))], 1)

        ############ 储存嵌入 ##############
        p_emb_collect.append(avg_emb_feature)
#             p_emb_collect.append(cat_emb)
        ############ 储存参数 ##############
        w_locals.append(copy.deepcopy(w))
    p_emb_collect = torch.squeeze(torch.cat(p_emb_collect,1).unsqueeze(0),1).to(args.device)
#     print(p_emb_collect.shape)
    
    action_next = dqn.choose_action_train(p_emb_collect)# list
    
#     if iter > threshold:
#         action_next = dqn.choose_action_train(p_emb_collect)# list
#     #开始更新之前没必要用错误的网络一直选择一个client，可能污染buffer
#     else:
#         action_next = random.sample(range(0,100),10)
#         print(action_next)

    
    
    
    ###########  计算当前轮的reward，然后将当前轮的reward添加到上一个replay_data中    
    net_glob.eval()
    global_acc, loss_train = test_img(net_glob, dataset_test, args)
#     print(global_acc)
    reward = constant ** (global_acc.numpy()/100 - target_acc) - 1
    if reward >= 0:
        print('well done')
        break
    loss_save.append(reward)
    
    if len(last_replay_data)==3:
#         last_replay_data.append(reward)#r
        last_replay_data.append(p_emb_collect)#s_next
        dqn.replay_buffer.add(last_replay_data[0], last_replay_data[1], last_replay_data[2], last_replay_data[3])
    
    last_replay_data = [p_emb_collect, torch.LongTensor(action_next).unsqueeze(0), reward]#s, a
    action = action_next
    
    # update global weights
    w_chosen = []
    for i in action:
        w_chosen.append(w_locals[i])
        
    w_glob = FedAvg(w_locals)
#     w_glob = FedPareto(w_locals, action_next)
#     w_glob = w_locals[action]
    
    #ours
#     w_glob = FedPareto(w_locals, weight, action)
    #更新w_save
#     w_save[action] = w_glob
        
    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)
    
    
    # print loss
    print('Round {:3d}, Average loss {:.3f}'.format(iter, reward))
#     loss_train.append(-reward)
    
    
    if iter > 30 and iter % 40 == 0:
        beta = beta_by_frame(iter)
        dqn.optimize(beta)
    
    #debug
#     if iter > 10:
#         beta = beta_by_frame(iter)
#         dqn.optimize(beta)
#     if iter > 10:
#         #如果刚好遇上buffer清空，就跳过一次更新
#         if len(dqn.replay_buffer.buffer) == 0:
#             print('buffer=0')
#             continue
#         else:
#             beta = beta_by_frame(iter)
#             dqn.optimize(beta)
    if iter % 500 == 0:
#         torch.save(dqn, 'dqn_{}_{}_{}_clientnumber{}_localep{}.pt'.format(args.dataset, args.epochs, args.model, m, args.local_ep))
        torch.save(dqn, 'dqn_{}_{}_{}_clientnumber{}_localep{}.pt'.format(args.dataset, args.epochs, args.model, m, args.local_ep))
        print("saved")
#     args.lr = max(args.lr*args.lr_decay, 0.001)



Round   0, Average loss -0.976
saved
Round   1, Average loss -0.972
Round   2, Average loss -0.902
Round   3, Average loss -0.664
Round   4, Average loss -0.485
Round   5, Average loss -0.420
Round   6, Average loss -0.388
Round   7, Average loss -0.351
Round   8, Average loss -0.331
Round   9, Average loss -0.295
Round  10, Average loss -0.283
Round  11, Average loss -0.273
Round  12, Average loss -0.266
Round  13, Average loss -0.252
Round  14, Average loss -0.246
Round  15, Average loss -0.240
Round  16, Average loss -0.237
Round  17, Average loss -0.228
Round  18, Average loss -0.219
Round  19, Average loss -0.222
Round  20, Average loss -0.215
Round  21, Average loss -0.221
Round  22, Average loss -0.221
Round  23, Average loss -0.221
Round  24, Average loss -0.223
Round  25, Average loss -0.227
Round  26, Average loss -0.230
Round  27, Average loss -0.222
Round  28, Average loss -0.229
Round  29, Average loss -0.237
Round  30, Average loss -0.218
Round  31, Average loss -0.244
Ro

KeyboardInterrupt: 

In [8]:
# fl
#initialize
if args.model == 'cnn' and args.dataset == 'cifar':
    net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
    net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
    len_in = 1
    for x in img_size:
        len_in *= x
    net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
else:
    exit('Error: unrecognized model')
# net_glob.train()
w_glob = net_glob.state_dict()

w_save = []
m = max(int(args.frac * args.num_users), 1)
dqn = torch.load('dqn_{}_{}_{}_clientnumber{}_localep{}.pt'.format(args.dataset, args.epochs, args.model, m, args.local_ep))
constant = 64
target_acc = 0.99
for iter in range(args.validation_epochs):
    if iter == 0:
        random_n = 0
        n_weight = []
        while random_n<args.num_users*args.frac:
            n_weight.append(random.random()) # 随机初始化参数
            random_n+=1
        action = random.sample(range(0,100),10)
        print(action)
        
    loss_locals = []
    w_locals = []
    p_emb_collect = []
    for i in layer_name:
        if i == 'conv1':
            emb_global = layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))
        else:
            emb_global += layer_dict[i].forward(torch.cat([net_glob.state_dict()[i+'.weight'].reshape(1,-1), net_glob.state_dict()[i+'.bias'].reshape(1,-1)], 1))
    emb_global = emb_global/4
    p_emb_collect.append(emb_global)
    
    for idx in range(100):
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        
        ##########  求均值方式  #######
        for i in layer_name:
            if i == 'conv1':
                emb_feature = layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))
            else:
                emb_feature += layer_dict[i].forward(torch.cat([w[i+'.weight'].reshape(1,-1), w[i+'.bias'].reshape(1,-1)], 1))

        #分母对应local网络的层数
        avg_emb_feature = emb_feature/4
        
        ############ 储存嵌入 ##############
        p_emb_collect.append(avg_emb_feature)
        ############ 储存参数 ##############
        w_locals.append(copy.deepcopy(w))
        
    p_emb_collect = torch.squeeze(torch.cat(p_emb_collect,1).unsqueeze(0),1).to(args.device)
    action_next = dqn.choose_action_run(p_emb_collect)
    
    action = action_next
    # update global weights
    #fedavg
    chosen_w = []
    for i in action:
        chosen_w.append(w_locals[i])
    w_glob = FedAvg(chosen_w)
    
    #ours
#     w_glob = FedPareto(w_locals, weight, action_next)
    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)
    
    #计算当前轮的reward   
#     net_glob.eval()
    global_acc, loss_train = test_img(net_glob, dataset_train, args)
    reward = constant ** (global_acc.numpy()/100 - target_acc) - 1
    
    # print loss
    print('Round {:3d}, Average loss {:.3f}'.format(iter, -reward))
#     loss_train.append(-reward)

#     args.lr = max(args.lr*args.lr_decay, 0.001)

[75, 50, 99, 74, 43, 71, 69, 78, 20, 70]
[83, 5, 2, 85, 59]
Round   0, Average loss 0.975
[83, 72, 5, 59, 85]
Round   1, Average loss 0.965
[83, 5, 72, 2, 85]
Round   2, Average loss 0.962
[83, 5, 72, 2, 59]
Round   3, Average loss 0.949
[83, 2, 72, 85, 5]
Round   4, Average loss 0.950
[2, 83, 5, 85, 72]
Round   5, Average loss 0.945
[5, 83, 72, 2, 85]
Round   6, Average loss 0.940
[83, 5, 72, 2, 85]
Round   7, Average loss 0.940
[83, 5, 72, 2, 8]
Round   8, Average loss 0.936
[5, 83, 85, 72, 2]
Round   9, Average loss 0.937


In [9]:
# plot loss curve
# plt.figure()
# plt.plot(range(len(loss_train)), loss_train)
# plt.ylabel('train_loss')
# plt.show()
# plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

# testing
net_glob.eval()
acc_train, loss_train = test_img(net_glob, dataset_train, args)
acc_test, loss_test = test_img(net_glob, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

Training accuracy: 32.56
Testing accuracy: 32.33


In [None]:
loss_train

In [14]:
loss_save

[-0.975371031855105,
 -0.8102594432248442,
 -0.975580118651963,
 -0.9007438634890736,
 -0.7761971886804484,
 -0.6804087589944856,
 -0.5142451953292079,
 -0.5662005594182116,
 -0.5218947723519878,
 -0.8102594432248442,
 -0.5218947723519878,
 -0.4699501956687363,
 -0.519869068681007,
 -0.9750860083538706,
 -0.519869068681007,
 -0.4451382740224047,
 -0.4999999471170642,
 -0.4421303270368043,
 -0.9750860083538706,
 -0.4421303270368043,
 -0.4680361833106459,
 -0.9750860083538706,
 -0.4680361833106459,
 -0.43590901479795074,
 -0.523581896570082,
 -0.43856139831845153,
 -0.4304082993402708,
 -0.9750860083538706,
 -0.4304082993402708,
 -0.9750860083538706,
 -0.4304082993402708,
 -0.4048603887714919,
 -0.4176332067657721,
 -0.46470721366917545,
 -0.9750860083538706,
 -0.46470721366917545,
 -0.9750860083538706,
 -0.46470721366917545,
 -0.44502294516640284,
 -0.41056687074612686,
 -0.4535733233867021,
 -0.9750860083538706,
 -0.4535733233867021,
 -0.9750860083538706,
 -0.4535733233867021,
 -0.4123

In [1]:
%matplotlib inline

plt.plot(range(len(loss_save)), loss_save)
plt.show

NameError: name 'plt' is not defined