In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# Experiment switches. The asserts at the end fail fast on a typo instead of letting a
# later cell crash with a confusing error (e.g. `None.transform` for an unknown dataset).

# one in ['ours', 'DSGD', 'SISA', 'Fed']
# NOTE(review): the cells below always build FedAvg clients/server regardless of this value;
# here it only ends up in log_path.
Framework = 'Fed'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'CIFAR10'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1, 0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 1

assert Framework in ('ours', 'DSGD', 'SISA', 'Fed'), 'unknown Framework: {}'.format(Framework)
assert DatasetName in ('MNIST', 'FMNIST', 'CIFAR10'), 'unknown DatasetName: {}'.format(DatasetName)
assert ModelType in ('resnet', 'mobilenet'), 'unknown ModelType: {}'.format(ModelType)
assert Heterogeneous in (0, 1), 'Heterogeneous must be 0 or 1, got {}'.format(Heterogeneous)

In [2]:
# --- Data and training configuration ----------------------------------------
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 5
num_channel = 1            # input channels; switched to 3 for CIFAR10 below
ref_size = 10000           # the first ref_size test samples form the reference set (ref_set)
# Samples per device slice. NOTE: 50000 equals the CIFAR10 training-set size; MNIST and
# FashionMNIST have larger training splits, so with this constant their tail goes unused.
train_test_total_size = int(50000/device_num)
test_ratio = 0.2           # fraction of each device's kept samples used as its local test split

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/unlearn_{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG in play. Seeding only CUDA would leave weight initialisation (models are
# built on CPU and then moved with .cuda) and DataLoader shuffling (CPU torch generator)
# unseeded, so runs would not be reproducible.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)  # seeds the CPU generator and all CUDA devices
epoch_num = 300 if ModelType == 'resnet' else 100   # default, overridden per dataset below

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    num_channel = 3
    epoch_num = 700 if ModelType == 'resnet' else 200
elif DatasetName == 'MNIST':
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)
    epoch_num = 300 if ModelType == 'resnet' else 50
elif DatasetName == 'FMNIST':
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)
    epoch_num = 500 if ModelType == 'resnet' else 200
else:
    # Fail here rather than with an AttributeError on `None.transform` below.
    raise ValueError('Unsupported DatasetName: {}'.format(DatasetName))

# Same preprocessing for train and test: ToTensor scales pixels to [0, 1], and
# Normalize(0.5, 0.5) maps them to [-1, 1]. The one-element mean/std is applied to every channel.
to_normalized_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set_o.transform = to_normalized_tensor
test_set_o.transform = to_normalized_tensor

# Reference data: the first ref_size samples of the test split, iterated in a fixed order.
ref_set = Subset(test_set_o, range(int(ref_size)))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Data manipulation: give each device a contiguous, non-overlapping slice of the training
# set. Drop one class from each slice (device i loses class i % num_classes). Then split
# the remaining samples into a local train part and a local test part.
device_dict = {}   # filled with FedAvg clients in the next cell
loader_dict = {}   # device_id -> [train_loader, test_loader]
print('\t'.join(['Device'] + ['Class' + str(class_id) for class_id in range(num_classes)] + ['SUM']))
for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size
    # remove one class from each local dataset (tied to num_classes, not a hard-coded 10)
    class_to_remove = device_id % num_classes
    # as_tensor avoids the copy-construct warning when targets is already a tensor
    # (torchvision MNIST/FashionMNIST); CIFAR10 targets are a plain list.
    local_targets = torch.as_tensor(train_set_o.targets[range_start:range_end])
    # positions, relative to range_start, of the samples whose class is kept
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]
    # split train&test: the first (1 - test_ratio) of the kept samples train, the rest test
    train_test_border = int((1-test_ratio)*len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border]+range_start)
    test_set = Subset(train_set_o, indices[train_test_border:]+range_start)
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    # NOTE(review): shuffling is unnecessary for evaluation but consumes the torch RNG;
    # kept so that seeded runs stay comparable with earlier results.
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=True, num_workers=0)
    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	764	817	807	806	765	830	809	807	791	7196
D-1	793	0	817	806	802	768	813	824	782	789	7194
D-2	798	807	0	817	790	815	787	806	781	827	7228
D-3	790	770	852	0	809	824	803	781	802	788	7219
D-4	822	816	770	806	0	831	778	778	788	813	7202


In [4]:
# Build one FedAvg client per device plus the server.
# The two scenarios differ only in architecture:
#   Heterogeneous == 1 -> ResNet8 / MobileNet_S
#   otherwise          -> ResNet50 / MobileNet_L
# So only the model class is chosen per branch, instead of duplicating the whole construction.
if Heterogeneous == 1:
    # heterogeneous scenario
    model_cls = ResNet8 if ModelType == 'resnet' else MobileNet_S
else:
    # homogeneous scenario
    model_cls = ResNet50 if ModelType == 'resnet' else MobileNet_L

gpu_id = 0  # bound before the loop so the server below never sees an unbound name
for device_id in range(device_num):
    new_client = FedAvgClient(device_id, gpu_id=gpu_id, num_classes=num_classes, num_channel=num_channel)
    new_client.client_model = model_cls(num_channel).cuda(gpu_id)
    new_client.optimizer = optim.Adam(new_client.client_model.parameters(), lr=0.01)
    device_dict[device_id] = new_client

# NOTE(review): device_dict.keys() is a live view. Later deletions and re-insertions in
# device_dict are therefore visible through it. Kept as-is; confirm FedAvgServer expects a
# live view rather than a snapshot such as list(device_dict).
Fed_server = FedAvgServer(Device_id_list = device_dict.keys(), gpu_id=gpu_id, num_classes=num_classes, num_channel=num_channel)
Fed_server.central_model = model_cls(num_channel).cuda(gpu_id)

In [5]:
# Hold one client out of the initial training; it is retrained and re-added later.
recover_id = 4
print(device_dict)
device_back_up = device_dict.pop(recover_id)
print(device_dict)

{0: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b919cf90>, 1: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b9111310>, 2: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b8d890d0>, 3: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b90ef850>, 4: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b89b1210>}
{0: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b919cf90>, 1: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b9111310>, 2: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b8d890d0>, 3: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b90ef850>}


In [6]:
# Train local main models. Each round:
#   1. every client in device_dict runs local updates;
#   2. the server stores the clients' historical updates, updates the central model, and
#      distributes it;
#   3. the central model is validated on each client's local test split, and the mean of
#      the first four metric columns is printed and logged.
write_log(log_path, time.ctime(time.time()))
for epoch in range(epoch_num):
    for client_id, client in device_dict.items():
        client.update_client_model(num_iter=3, local_loader=loader_dict[client_id][0])

    Fed_server.store_client_historical_update(device_dict)
    Fed_server.update_central_model(frac=1, device_dict=device_dict)
    Fed_server.distribute_central_model(device_dict=device_dict)

    metric = [Fed_server.validate_central_model(test_loader=loader_dict[cid][1])
              for cid in device_dict]
    metric_arr = np.array(metric)
    mean_metric = np.mean(metric_arr, axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metric[:4])
    print('Epoch:{} mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)

Client 0	Epoch:  0		Loss: 1.72305548
Client 1	Epoch:  0		Loss: 1.94414341
Client 2	Epoch:  0		Loss: 1.65297520
Client 3	Epoch:  0		Loss: 1.57459354
Server - Avg_loss: 0.0742, Acc: 172.0/1799 (0.0956)
Server - Avg_loss: 0.0729, Acc: 220.0/1799 (0.1223)
Server - Avg_loss: 0.0750, Acc: 214.0/1807 (0.1184)
Server - Avg_loss: 0.0737, Acc: 197.0/1805 (0.1091)
Epoch:0 mean: 0.1114	0.0128	0.1116	0.0225	

Client 0	Epoch:  0		Loss: 1.81485045
Client 1	Epoch:  0		Loss: 1.75699914
Client 2	Epoch:  0		Loss: 1.41465676
Client 3	Epoch:  0		Loss: 1.64645815
Server - Avg_loss: 0.1178, Acc: 210.0/1799 (0.1167)
Server - Avg_loss: 0.1041, Acc: 231.0/1799 (0.1284)
Server - Avg_loss: 0.1290, Acc: 233.0/1807 (0.1289)
Server - Avg_loss: 0.1194, Acc: 238.0/1805 (0.1319)
Epoch:1 mean: 0.1265	0.0364	0.1245	0.0483	

Client 0	Epoch:  0		Loss: 1.22578466
Client 1	Epoch:  0		Loss: 1.14464283
Client 2	Epoch:  0		Loss: 1.37774158
Client 3	Epoch:  0		Loss: 1.57488322
Server - Avg_loss: 0.0955, Acc: 326.0/1799 (0.1812)


Server - Avg_loss: 0.0445, Acc: 1019.0/1799 (0.5664)
Server - Avg_loss: 0.0434, Acc: 1047.0/1799 (0.5820)
Server - Avg_loss: 0.0394, Acc: 1064.0/1807 (0.5888)
Server - Avg_loss: 0.0378, Acc: 1115.0/1805 (0.6177)
Epoch:20 mean: 0.5887	0.5315	0.5781	0.5011	

Client 0	Epoch:  0		Loss: 0.73182422
Client 1	Epoch:  0		Loss: 1.00935531
Client 2	Epoch:  0		Loss: 0.56005722
Client 3	Epoch:  0		Loss: 0.88213980
Server - Avg_loss: 0.0424, Acc: 1042.0/1799 (0.5792)
Server - Avg_loss: 0.0426, Acc: 1045.0/1799 (0.5809)
Server - Avg_loss: 0.0388, Acc: 1056.0/1807 (0.5844)
Server - Avg_loss: 0.0392, Acc: 1104.0/1805 (0.6116)
Epoch:21 mean: 0.5890	0.5265	0.5713	0.4950	

Client 0	Epoch:  0		Loss: 0.80794650
Client 1	Epoch:  0		Loss: 1.00793517
Client 2	Epoch:  0		Loss: 0.75246209
Client 3	Epoch:  0		Loss: 0.82102716
Server - Avg_loss: 0.0425, Acc: 1052.0/1799 (0.5848)
Server - Avg_loss: 0.0426, Acc: 1078.0/1799 (0.5992)
Server - Avg_loss: 0.0387, Acc: 1091.0/1807 (0.6038)
Server - Avg_loss: 0.0370, Acc:

Server - Avg_loss: 0.0347, Acc: 1169.0/1807 (0.6469)
Server - Avg_loss: 0.0333, Acc: 1208.0/1805 (0.6693)
Epoch:40 mean: 0.6321	0.5696	0.6103	0.5472	

Client 0	Epoch:  0		Loss: 0.58248997
Client 1	Epoch:  0		Loss: 0.83307087
Client 2	Epoch:  0		Loss: 0.47780451
Client 3	Epoch:  0		Loss: 0.69833559
Server - Avg_loss: 0.0398, Acc: 1094.0/1799 (0.6081)
Server - Avg_loss: 0.0404, Acc: 1081.0/1799 (0.6009)
Server - Avg_loss: 0.0363, Acc: 1139.0/1807 (0.6303)
Server - Avg_loss: 0.0342, Acc: 1176.0/1805 (0.6515)
Epoch:41 mean: 0.6227	0.5714	0.6030	0.5406	

Client 0	Epoch:  0		Loss: 0.82169056
Client 1	Epoch:  0		Loss: 0.65578866
Client 2	Epoch:  0		Loss: 0.82048273
Client 3	Epoch:  0		Loss: 0.63608509
Server - Avg_loss: 0.0406, Acc: 1073.0/1799 (0.5964)
Server - Avg_loss: 0.0408, Acc: 1073.0/1799 (0.5964)
Server - Avg_loss: 0.0365, Acc: 1137.0/1807 (0.6292)
Server - Avg_loss: 0.0345, Acc: 1175.0/1805 (0.6510)
Epoch:42 mean: 0.6183	0.5757	0.6099	0.5417	

Client 0	Epoch:  0		Loss: 0.79779541
Cl

Client 0	Epoch:  0		Loss: 0.49540493
Client 1	Epoch:  0		Loss: 1.02283406
Client 2	Epoch:  0		Loss: 0.48953310
Client 3	Epoch:  0		Loss: 0.69086838
Server - Avg_loss: 0.0384, Acc: 1104.0/1799 (0.6137)
Server - Avg_loss: 0.0385, Acc: 1087.0/1799 (0.6042)
Server - Avg_loss: 0.0337, Acc: 1168.0/1807 (0.6464)
Server - Avg_loss: 0.0342, Acc: 1213.0/1805 (0.6720)
Epoch:61 mean: 0.6341	0.5863	0.6077	0.5492	

Client 0	Epoch:  0		Loss: 0.98124713
Client 1	Epoch:  0		Loss: 0.72069442
Client 2	Epoch:  0		Loss: 1.08542609
Client 3	Epoch:  0		Loss: 0.41118106
Server - Avg_loss: 0.0410, Acc: 1079.0/1799 (0.5998)
Server - Avg_loss: 0.0404, Acc: 1084.0/1799 (0.6026)
Server - Avg_loss: 0.0363, Acc: 1146.0/1807 (0.6342)
Server - Avg_loss: 0.0354, Acc: 1191.0/1805 (0.6598)
Epoch:62 mean: 0.6241	0.5648	0.6089	0.5386	

Client 0	Epoch:  0		Loss: 0.67916048
Client 1	Epoch:  0		Loss: 0.92129016
Client 2	Epoch:  0		Loss: 0.74787408
Client 3	Epoch:  0		Loss: 0.58070999
Server - Avg_loss: 0.0406, Acc: 1097.0/179

Client 3	Epoch:  0		Loss: 0.55816615
Server - Avg_loss: 0.0394, Acc: 1108.0/1799 (0.6159)
Server - Avg_loss: 0.0371, Acc: 1125.0/1799 (0.6253)
Server - Avg_loss: 0.0344, Acc: 1165.0/1807 (0.6447)
Server - Avg_loss: 0.0328, Acc: 1219.0/1805 (0.6753)
Epoch:81 mean: 0.6403	0.5877	0.6177	0.5580	

Client 0	Epoch:  0		Loss: 1.02452004
Client 1	Epoch:  0		Loss: 0.70777416
Client 2	Epoch:  0		Loss: 0.48362175
Client 3	Epoch:  0		Loss: 1.05679727
Server - Avg_loss: 0.0404, Acc: 1089.0/1799 (0.6053)
Server - Avg_loss: 0.0407, Acc: 1084.0/1799 (0.6026)
Server - Avg_loss: 0.0351, Acc: 1145.0/1807 (0.6336)
Server - Avg_loss: 0.0355, Acc: 1197.0/1805 (0.6632)
Epoch:82 mean: 0.6262	0.5770	0.6077	0.5440	

Client 0	Epoch:  0		Loss: 0.84929287
Client 1	Epoch:  0		Loss: 0.55060428
Client 2	Epoch:  0		Loss: 0.73354113
Client 3	Epoch:  0		Loss: 0.85103244
Server - Avg_loss: 0.0401, Acc: 1103.0/1799 (0.6131)
Server - Avg_loss: 0.0389, Acc: 1106.0/1799 (0.6148)
Server - Avg_loss: 0.0340, Acc: 1177.0/1807 (0.

Server - Avg_loss: 0.0339, Acc: 1196.0/1807 (0.6619)
Server - Avg_loss: 0.0323, Acc: 1235.0/1805 (0.6842)
Epoch:101 mean: 0.6478	0.5909	0.6241	0.5659	

Client 0	Epoch:  0		Loss: 0.99788153
Client 1	Epoch:  0		Loss: 0.87888169
Client 2	Epoch:  0		Loss: 0.41695124
Client 3	Epoch:  0		Loss: 0.99117178
Server - Avg_loss: 0.0408, Acc: 1100.0/1799 (0.6115)
Server - Avg_loss: 0.0395, Acc: 1100.0/1799 (0.6115)
Server - Avg_loss: 0.0351, Acc: 1169.0/1807 (0.6469)
Server - Avg_loss: 0.0349, Acc: 1215.0/1805 (0.6731)
Epoch:102 mean: 0.6357	0.5804	0.6170	0.5526	

Client 0	Epoch:  0		Loss: 0.97952425
Client 1	Epoch:  0		Loss: 0.59608191
Client 2	Epoch:  0		Loss: 0.61277360
Client 3	Epoch:  0		Loss: 0.24272262
Server - Avg_loss: 0.0406, Acc: 1093.0/1799 (0.6076)
Server - Avg_loss: 0.0388, Acc: 1115.0/1799 (0.6198)
Server - Avg_loss: 0.0362, Acc: 1148.0/1807 (0.6353)
Server - Avg_loss: 0.0344, Acc: 1215.0/1805 (0.6731)
Epoch:103 mean: 0.6339	0.5803	0.6112	0.5511	

Client 0	Epoch:  0		Loss: 0.56278712

Client 0	Epoch:  0		Loss: 0.59690291
Client 1	Epoch:  0		Loss: 1.05101752
Client 2	Epoch:  0		Loss: 0.98886240
Client 3	Epoch:  0		Loss: 0.96036518
Server - Avg_loss: 0.0399, Acc: 1134.0/1799 (0.6304)
Server - Avg_loss: 0.0383, Acc: 1097.0/1799 (0.6098)
Server - Avg_loss: 0.0327, Acc: 1196.0/1807 (0.6619)
Server - Avg_loss: 0.0323, Acc: 1242.0/1805 (0.6881)
Epoch:122 mean: 0.6475	0.5967	0.6265	0.5697	

Client 0	Epoch:  0		Loss: 0.67258722
Client 1	Epoch:  0		Loss: 0.51986229
Client 2	Epoch:  0		Loss: 0.47377905
Client 3	Epoch:  0		Loss: 0.79903513
Server - Avg_loss: 0.0395, Acc: 1112.0/1799 (0.6181)
Server - Avg_loss: 0.0387, Acc: 1105.0/1799 (0.6142)
Server - Avg_loss: 0.0343, Acc: 1169.0/1807 (0.6469)
Server - Avg_loss: 0.0341, Acc: 1225.0/1805 (0.6787)
Epoch:123 mean: 0.6395	0.5952	0.6163	0.5600	

Client 0	Epoch:  0		Loss: 0.69677776
Client 1	Epoch:  0		Loss: 0.79022217
Client 2	Epoch:  0		Loss: 0.35531560
Client 3	Epoch:  0		Loss: 0.71135634
Server - Avg_loss: 0.0378, Acc: 1141.0/1

Client 2	Epoch:  0		Loss: 0.32721427
Client 3	Epoch:  0		Loss: 0.50321597
Server - Avg_loss: 0.0397, Acc: 1125.0/1799 (0.6253)
Server - Avg_loss: 0.0389, Acc: 1117.0/1799 (0.6209)
Server - Avg_loss: 0.0334, Acc: 1198.0/1807 (0.6630)
Server - Avg_loss: 0.0333, Acc: 1228.0/1805 (0.6803)
Epoch:142 mean: 0.6474	0.5957	0.6253	0.5687	

Client 0	Epoch:  0		Loss: 0.62193674
Client 1	Epoch:  0		Loss: 0.77261567


KeyboardInterrupt: 

In [7]:
# Unlearn one client. Ask the server to unlearn it, remove it from device_dict, then
# validate the resulting central model on the remaining clients' local test splits.
unlearn_id = 3
Fed_server.unlearn_client(unlearn_id)
write_log(log_path, '\n unlearn_id: {}'.format(unlearn_id))
device_dict.pop(unlearn_id)
print(device_dict)

metric = [Fed_server.validate_central_model(test_loader=loader_dict[cid][1])
          for cid in device_dict]
metric_arr = np.array(metric)
mean_metric = np.mean(metric_arr, axis=0)
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metric[:4])
print('mean: {}'.format(log_txt))
write_log(log_path, log_txt)

# Prepare the remedy phase: refresh the server's soft labels on the reference data and
# give the central model a fresh Adam optimizer.
Fed_server.update_soft_label(ref_loader)
Fed_server.optimizer = optim.Adam(Fed_server.central_model.parameters(), lr=0.01)

{0: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b919cf90>, 1: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b9111310>, 2: <framework.FedAvgUnl.FedAvgClient object at 0x7fb7b8d890d0>}
Server - Avg_loss: nan, Acc: 0.0/1799 (0.0000)
Server - Avg_loss: nan, Acc: 191.0/1799 (0.1062)
Server - Avg_loss: nan, Acc: 196.0/1807 (0.1085)
mean: 0.0715	0.0081	0.0735	0.0143	


In [8]:
# Remedy central model: repeatedly run the server's remedy step on the reference set, then
# validate the central model on the remaining clients after every step.
# NOTE(review): this rebinds the global epoch_num. The recovery loop further down also reads
# epoch_num, so that loop runs 120 rounds too. Use a separate name if that coupling is unintended.
epoch_num = 120
for epoch in range(epoch_num):
    # training_time appears to be a per-call wall-clock budget in seconds: each call logs a
    # "remedy time consumption(s)" of about 60.
    Fed_server.remedy_central_model(ref_set, training_time=60, batch_size=2000)

    metric = [Fed_server.validate_central_model(test_loader=loader_dict[cid][1])
              for cid in device_dict]
    metric_arr = np.array(metric)
    mean_metric = np.mean(metric_arr, axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metric[:4])
    print('Epoch:{} mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)

remedy_central_model - Epoch:   0 	Loss: 642.935303
remedy_central_model - Epoch:  50 	Loss: 550.470825
remedy_central_model - Epoch: 100 	Loss: 468.665588
remedy_central_model - Epoch: 150 	Loss: 388.066895
remedy_central_model - Epoch: 200 	Loss: 311.122833
remedy_central_model - Epoch: 250 	Loss: 265.903656
remedy_central_model - Epoch: 300 	Loss: 230.074646
remedy_central_model - Epoch: 350 	Loss: 172.611359
remedy_central_model - Epoch: 400 	Loss: 147.834290
remedy_central_model - Epoch: 450 	Loss: 134.394928
remedy_central_model - Epoch: 500 	Loss: 124.205399
remedy time consumption(s): 60.038095235824585
Server - Avg_loss: 11.2139, Acc: 209.0/1799 (0.1162)
Server - Avg_loss: 11.5612, Acc: 189.0/1799 (0.1051)
Server - Avg_loss: 11.4591, Acc: 180.0/1807 (0.0996)
Epoch:0 mean: 0.1069	0.0122	0.1110	0.0217	

remedy_central_model - Epoch:   0 	Loss: 121.822586
remedy_central_model - Epoch:  50 	Loss: 113.873787
remedy_central_model - Epoch: 100 	Loss: 107.078339
remedy_central_model -

remedy_central_model - Epoch: 100 	Loss: 26.220348
remedy_central_model - Epoch: 150 	Loss: 24.839853
remedy_central_model - Epoch: 200 	Loss: 23.511904
remedy_central_model - Epoch: 250 	Loss: 22.120483
remedy_central_model - Epoch: 300 	Loss: 21.174021
remedy_central_model - Epoch: 350 	Loss: 19.803638
remedy_central_model - Epoch: 400 	Loss: 18.476614
remedy_central_model - Epoch: 450 	Loss: 17.406681
remedy_central_model - Epoch: 500 	Loss: 16.022285
remedy time consumption(s): 60.079994916915894
Server - Avg_loss: 1.7246, Acc: 209.0/1799 (0.1162)
Server - Avg_loss: 1.6810, Acc: 189.0/1799 (0.1051)
Server - Avg_loss: 1.7374, Acc: 180.0/1807 (0.0996)
Epoch:10 mean: 0.1069	0.0122	0.1128	0.0215	

remedy_central_model - Epoch:   0 	Loss: 15.802247
remedy_central_model - Epoch:  50 	Loss: 14.476057
remedy_central_model - Epoch: 100 	Loss: 13.073749
remedy_central_model - Epoch: 150 	Loss: 12.002485
remedy_central_model - Epoch: 200 	Loss: 10.838688
remedy_central_model - Epoch: 250 	Los

remedy_central_model - Epoch: 350 	Loss: 0.677180
remedy_central_model - Epoch: 400 	Loss: 0.664279
remedy_central_model - Epoch: 450 	Loss: 0.657458
remedy_central_model - Epoch: 500 	Loss: 0.629622
remedy time consumption(s): 60.045448303222656
Server - Avg_loss: 0.0814, Acc: 200.0/1799 (0.1112)
Server - Avg_loss: 0.0794, Acc: 195.0/1799 (0.1084)
Server - Avg_loss: 0.0803, Acc: 191.0/1807 (0.1057)
Epoch:20 mean: 0.1084	0.0124	0.1112	0.0219	

remedy_central_model - Epoch:   0 	Loss: 0.654652
remedy_central_model - Epoch:  50 	Loss: 0.661065
remedy_central_model - Epoch: 100 	Loss: 0.648445
remedy_central_model - Epoch: 150 	Loss: 0.662570
remedy_central_model - Epoch: 200 	Loss: 0.658228
remedy_central_model - Epoch: 250 	Loss: 0.639968
remedy_central_model - Epoch: 300 	Loss: 0.662875
remedy_central_model - Epoch: 350 	Loss: 0.662746
remedy_central_model - Epoch: 400 	Loss: 0.659095
remedy_central_model - Epoch: 450 	Loss: 0.647799
remedy_central_model - Epoch: 500 	Loss: 0.652420
re

Server - Avg_loss: 0.0816, Acc: 193.0/1799 (0.1073)
Server - Avg_loss: 0.0797, Acc: 208.0/1799 (0.1156)
Server - Avg_loss: 0.0805, Acc: 200.0/1807 (0.1107)
Epoch:30 mean: 0.1112	0.0127	0.1121	0.0223	

remedy_central_model - Epoch:   0 	Loss: 0.641660
remedy_central_model - Epoch:  50 	Loss: 0.660222
remedy_central_model - Epoch: 100 	Loss: 0.642070
remedy_central_model - Epoch: 150 	Loss: 0.644472
remedy_central_model - Epoch: 200 	Loss: 0.657073
remedy_central_model - Epoch: 250 	Loss: 0.656134
remedy_central_model - Epoch: 300 	Loss: 0.655194
remedy_central_model - Epoch: 350 	Loss: 0.651412
remedy_central_model - Epoch: 400 	Loss: 0.653390
remedy_central_model - Epoch: 450 	Loss: 0.665576
remedy_central_model - Epoch: 500 	Loss: 0.655343
remedy time consumption(s): 60.10364031791687
Server - Avg_loss: 0.0817, Acc: 200.0/1799 (0.1112)
Server - Avg_loss: 0.0796, Acc: 195.0/1799 (0.1084)
Server - Avg_loss: 0.0806, Acc: 191.0/1807 (0.1057)
Epoch:31 mean: 0.1084	0.0123	0.1103	0.0218	

re

remedy_central_model - Epoch:  50 	Loss: 0.648705
remedy_central_model - Epoch: 100 	Loss: 0.653282
remedy_central_model - Epoch: 150 	Loss: 0.654846
remedy_central_model - Epoch: 200 	Loss: 0.643736
remedy_central_model - Epoch: 250 	Loss: 0.655276
remedy_central_model - Epoch: 300 	Loss: 0.654253
remedy_central_model - Epoch: 350 	Loss: 0.636529
remedy_central_model - Epoch: 400 	Loss: 0.671252
remedy_central_model - Epoch: 450 	Loss: 0.671965
remedy_central_model - Epoch: 500 	Loss: 0.656717
remedy time consumption(s): 60.02760100364685
Server - Avg_loss: 0.0810, Acc: 189.0/1799 (0.1051)
Server - Avg_loss: 0.0792, Acc: 218.0/1799 (0.1212)
Server - Avg_loss: 0.0804, Acc: 202.0/1807 (0.1118)
Epoch:41 mean: 0.1127	0.0251	0.1107	0.0338	

remedy_central_model - Epoch:   0 	Loss: 0.662689
remedy_central_model - Epoch:  50 	Loss: 0.659009
remedy_central_model - Epoch: 100 	Loss: 0.654303
remedy_central_model - Epoch: 150 	Loss: 0.654960
remedy_central_model - Epoch: 200 	Loss: 0.652510
rem

remedy_central_model - Epoch: 300 	Loss: 0.440301
remedy_central_model - Epoch: 350 	Loss: 0.436775
remedy_central_model - Epoch: 400 	Loss: 0.455510
remedy_central_model - Epoch: 450 	Loss: 0.442327
remedy_central_model - Epoch: 500 	Loss: 0.432455
remedy time consumption(s): 60.07043814659119
Server - Avg_loss: 0.0713, Acc: 469.0/1799 (0.2607)
Server - Avg_loss: 0.0696, Acc: 469.0/1799 (0.2607)
Server - Avg_loss: 0.0657, Acc: 438.0/1807 (0.2424)
Epoch:51 mean: 0.2546	0.1753	0.2525	0.1810	

remedy_central_model - Epoch:   0 	Loss: 0.433776
remedy_central_model - Epoch:  50 	Loss: 0.444845
remedy_central_model - Epoch: 100 	Loss: 0.431353
remedy_central_model - Epoch: 150 	Loss: 0.436843
remedy_central_model - Epoch: 200 	Loss: 0.425657
remedy_central_model - Epoch: 250 	Loss: 0.437952
remedy_central_model - Epoch: 300 	Loss: 0.431749
remedy_central_model - Epoch: 350 	Loss: 0.417105
remedy_central_model - Epoch: 400 	Loss: 0.413209
remedy_central_model - Epoch: 450 	Loss: 0.411576
rem

remedy time consumption(s): 60.076923847198486
Server - Avg_loss: 0.0699, Acc: 523.0/1799 (0.2907)
Server - Avg_loss: 0.0714, Acc: 531.0/1799 (0.2952)
Server - Avg_loss: 0.0642, Acc: 522.0/1807 (0.2889)
Epoch:61 mean: 0.2916	0.1892	0.2844	0.1954	

remedy_central_model - Epoch:   0 	Loss: 0.339415
remedy_central_model - Epoch:  50 	Loss: 0.349523
remedy_central_model - Epoch: 100 	Loss: 0.348984
remedy_central_model - Epoch: 150 	Loss: 0.354004
remedy_central_model - Epoch: 200 	Loss: 0.356340
remedy_central_model - Epoch: 250 	Loss: 0.342491
remedy_central_model - Epoch: 300 	Loss: 0.353364
remedy_central_model - Epoch: 350 	Loss: 0.347685
remedy_central_model - Epoch: 400 	Loss: 0.336057
remedy_central_model - Epoch: 450 	Loss: 0.335632
remedy_central_model - Epoch: 500 	Loss: 0.333511
remedy time consumption(s): 60.0441689491272
Server - Avg_loss: 0.0696, Acc: 510.0/1799 (0.2835)
Server - Avg_loss: 0.0729, Acc: 554.0/1799 (0.3079)
Server - Avg_loss: 0.0624, Acc: 570.0/1807 (0.3154)
E

remedy_central_model - Epoch:  50 	Loss: 0.306228
remedy_central_model - Epoch: 100 	Loss: 0.314837
remedy_central_model - Epoch: 150 	Loss: 0.306444
remedy_central_model - Epoch: 200 	Loss: 0.316722
remedy_central_model - Epoch: 250 	Loss: 0.312899
remedy_central_model - Epoch: 300 	Loss: 0.309561
remedy_central_model - Epoch: 350 	Loss: 0.312656
remedy_central_model - Epoch: 400 	Loss: 0.305741
remedy_central_model - Epoch: 450 	Loss: 0.313184
remedy_central_model - Epoch: 500 	Loss: 0.314065
remedy time consumption(s): 60.01922607421875
Server - Avg_loss: 0.0825, Acc: 504.0/1799 (0.2802)
Server - Avg_loss: 0.0861, Acc: 533.0/1799 (0.2963)
Server - Avg_loss: 0.0763, Acc: 532.0/1807 (0.2944)
Epoch:72 mean: 0.2903	0.1996	0.2872	0.1987	

remedy_central_model - Epoch:   0 	Loss: 0.315850
remedy_central_model - Epoch:  50 	Loss: 0.301325
remedy_central_model - Epoch: 100 	Loss: 0.306446
remedy_central_model - Epoch: 150 	Loss: 0.307995
remedy_central_model - Epoch: 200 	Loss: 0.307565
rem

remedy_central_model - Epoch: 300 	Loss: 0.294154
remedy_central_model - Epoch: 350 	Loss: 0.287045
remedy_central_model - Epoch: 400 	Loss: 0.297531
remedy_central_model - Epoch: 450 	Loss: 0.292974
remedy_central_model - Epoch: 500 	Loss: 0.296705
remedy time consumption(s): 60.08557629585266
Server - Avg_loss: 0.0718, Acc: 570.0/1799 (0.3168)
Server - Avg_loss: 0.0697, Acc: 609.0/1799 (0.3385)
Server - Avg_loss: 0.0663, Acc: 591.0/1807 (0.3271)
Epoch:82 mean: 0.3275	0.2390	0.3315	0.2459	

remedy_central_model - Epoch:   0 	Loss: 0.297156
remedy_central_model - Epoch:  50 	Loss: 0.296382
remedy_central_model - Epoch: 100 	Loss: 0.295910
remedy_central_model - Epoch: 150 	Loss: 0.297112
remedy_central_model - Epoch: 200 	Loss: 0.299148
remedy_central_model - Epoch: 250 	Loss: 0.297054
remedy_central_model - Epoch: 300 	Loss: 0.293142
remedy_central_model - Epoch: 350 	Loss: 0.299675
remedy_central_model - Epoch: 400 	Loss: 0.305174
remedy_central_model - Epoch: 450 	Loss: 0.300386
rem

remedy time consumption(s): 60.09761452674866
Server - Avg_loss: 0.0666, Acc: 603.0/1799 (0.3352)
Server - Avg_loss: 0.0675, Acc: 627.0/1799 (0.3485)
Server - Avg_loss: 0.0609, Acc: 629.0/1807 (0.3481)
Epoch:92 mean: 0.3439	0.2485	0.3376	0.2498	

remedy_central_model - Epoch:   0 	Loss: 0.279498
remedy_central_model - Epoch:  50 	Loss: 0.285493
remedy_central_model - Epoch: 100 	Loss: 0.289242
remedy_central_model - Epoch: 150 	Loss: 0.280161
remedy_central_model - Epoch: 200 	Loss: 0.281522
remedy_central_model - Epoch: 250 	Loss: 0.284003
remedy_central_model - Epoch: 300 	Loss: 0.285112
remedy_central_model - Epoch: 350 	Loss: 0.283754
remedy_central_model - Epoch: 400 	Loss: 0.282887
remedy_central_model - Epoch: 450 	Loss: 0.279645
remedy_central_model - Epoch: 500 	Loss: 0.285124
remedy time consumption(s): 60.038161516189575
Server - Avg_loss: 0.0662, Acc: 617.0/1799 (0.3430)
Server - Avg_loss: 0.0670, Acc: 613.0/1799 (0.3407)
Server - Avg_loss: 0.0614, Acc: 619.0/1807 (0.3426)


remedy_central_model - Epoch:  50 	Loss: 0.275522
remedy_central_model - Epoch: 100 	Loss: 0.273160
remedy_central_model - Epoch: 150 	Loss: 0.281589
remedy_central_model - Epoch: 200 	Loss: 0.285831
remedy_central_model - Epoch: 250 	Loss: 0.271620
remedy_central_model - Epoch: 300 	Loss: 0.279225
remedy_central_model - Epoch: 350 	Loss: 0.281251
remedy_central_model - Epoch: 400 	Loss: 0.268633
remedy_central_model - Epoch: 450 	Loss: 0.278180
remedy_central_model - Epoch: 500 	Loss: 0.280927
remedy time consumption(s): 60.07167935371399
Server - Avg_loss: 0.0681, Acc: 610.0/1799 (0.3391)
Server - Avg_loss: 0.0659, Acc: 642.0/1799 (0.3569)
Server - Avg_loss: 0.0635, Acc: 614.0/1807 (0.3398)
Epoch:103 mean: 0.3452	0.2481	0.3430	0.2558	

remedy_central_model - Epoch:   0 	Loss: 0.281443
remedy_central_model - Epoch:  50 	Loss: 0.273683
remedy_central_model - Epoch: 100 	Loss: 0.279880
remedy_central_model - Epoch: 150 	Loss: 0.274249
remedy_central_model - Epoch: 200 	Loss: 0.276116
re

remedy_central_model - Epoch: 300 	Loss: 0.282202
remedy_central_model - Epoch: 350 	Loss: 0.272584
remedy_central_model - Epoch: 400 	Loss: 0.282917
remedy_central_model - Epoch: 450 	Loss: 0.274860
remedy_central_model - Epoch: 500 	Loss: 0.270189
remedy time consumption(s): 60.044762134552
Server - Avg_loss: 0.0683, Acc: 598.0/1799 (0.3324)
Server - Avg_loss: 0.0683, Acc: 632.0/1799 (0.3513)
Server - Avg_loss: 0.0629, Acc: 641.0/1807 (0.3547)
Epoch:113 mean: 0.3461	0.2460	0.3374	0.2539	

remedy_central_model - Epoch:   0 	Loss: 0.285125
remedy_central_model - Epoch:  50 	Loss: 0.270915
remedy_central_model - Epoch: 100 	Loss: 0.276868
remedy_central_model - Epoch: 150 	Loss: 0.273103
remedy_central_model - Epoch: 200 	Loss: 0.270902
remedy_central_model - Epoch: 250 	Loss: 0.283100
remedy_central_model - Epoch: 300 	Loss: 0.264118
remedy_central_model - Epoch: 350 	Loss: 0.266548
remedy_central_model - Epoch: 400 	Loss: 0.295053
remedy_central_model - Epoch: 450 	Loss: 0.268950
reme

In [9]:
# Train a new client: the held-out client (recover_id) trains locally for 100 iterations
# on its own training split and then rejoins device_dict.
recover_train_loader = loader_dict[recover_id][0]
device_back_up.update_client_model(num_iter=100, local_loader=recover_train_loader)
device_dict.update({recover_id: device_back_up})

Client 4	Epoch:  0		Loss: 2.17156196
Client 4	Epoch: 10		Loss: 1.20499730
Client 4	Epoch: 20		Loss: 1.03735507
Client 4	Epoch: 30		Loss: 1.10371828
Client 4	Epoch: 40		Loss: 0.74466300
Client 4	Epoch: 50		Loss: 0.65222019
Client 4	Epoch: 60		Loss: 0.78575552
Client 4	Epoch: 70		Loss: 0.64386511
Client 4	Epoch: 80		Loss: 0.72812331
Client 4	Epoch: 90		Loss: 0.80482370


In [10]:
# Join into network: with the retrained client back in device_dict, resume federated rounds.
# Each round follows the same steps as the initial training loop: local updates, then the
# server stores historical updates and updates/distributes the central model, then each
# client's test split is validated and the mean is logged.
# NOTE(review): epoch_num here is the value last assigned by the remedy cell (120), not
# the per-dataset value from the configuration cell.
write_log(log_path, 'Recovering')
write_log(log_path, time.ctime(time.time()))
for epoch in range(epoch_num):
    for client_id, client in device_dict.items():
        client.update_client_model(num_iter=3, local_loader=loader_dict[client_id][0])

    Fed_server.store_client_historical_update(device_dict)
    Fed_server.update_central_model(frac=1, device_dict=device_dict)
    Fed_server.distribute_central_model(device_dict=device_dict)

    metric = []
    for client_id, client in device_dict.items():
        metric.append(Fed_server.validate_central_model(test_loader=loader_dict[client_id][1]))
    metric_arr = np.array(metric)
    mean_metric = np.mean(metric_arr, axis=0)   # computed once instead of four times
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metric[:4])
    print('Epoch:{} mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)

Client 0	Epoch:  0		Loss: 0.88992149
Client 1	Epoch:  0		Loss: 0.57626021
Client 2	Epoch:  0		Loss: 0.77578562
Client 4	Epoch:  0		Loss: 1.19269216
Server - Avg_loss: 0.0187, Acc: 236.0/1799 (0.1312)
Server - Avg_loss: 0.0194, Acc: 223.0/1799 (0.1240)
Server - Avg_loss: 0.0200, Acc: 217.0/1807 (0.1201)
Server - Avg_loss: 0.0207, Acc: 209.0/1805 (0.1158)
Epoch:0 mean: 0.1228	0.0634	0.1210	0.0421	

Client 0	Epoch:  0		Loss: 0.74496192
Client 1	Epoch:  0		Loss: 0.82355976
Client 2	Epoch:  0		Loss: 1.15120542
Client 4	Epoch:  0		Loss: 1.03749311
Server - Avg_loss: 0.0056, Acc: 1036.0/1799 (0.5759)
Server - Avg_loss: 0.0053, Acc: 1090.0/1799 (0.6059)
Server - Avg_loss: 0.0052, Acc: 1069.0/1807 (0.5916)
Server - Avg_loss: 0.0050, Acc: 1157.0/1805 (0.6410)
Epoch:1 mean: 0.6036	0.5718	0.5505	0.5172	

Client 0	Epoch:  0		Loss: 0.70812660
Client 1	Epoch:  0		Loss: 0.83334744
Client 2	Epoch:  0		Loss: 0.77147925
Client 4	Epoch:  0		Loss: 0.73275983
Server - Avg_loss: 0.0057, Acc: 1028.0/1799 (0.5

Client 4	Epoch:  0		Loss: 0.70318097
Server - Avg_loss: 0.0052, Acc: 1038.0/1799 (0.5770)
Server - Avg_loss: 0.0052, Acc: 1081.0/1799 (0.6009)
Server - Avg_loss: 0.0056, Acc: 1112.0/1807 (0.6154)
Server - Avg_loss: 0.0051, Acc: 1164.0/1805 (0.6449)
Epoch:20 mean: 0.6095	0.5836	0.5556	0.5292	

Client 0	Epoch:  0		Loss: 0.96391582
Client 1	Epoch:  0		Loss: 1.08659852
Client 2	Epoch:  0		Loss: 0.63434327
Client 4	Epoch:  0		Loss: 0.76654118
Server - Avg_loss: 0.0058, Acc: 1048.0/1799 (0.5825)
Server - Avg_loss: 0.0058, Acc: 1070.0/1799 (0.5948)
Server - Avg_loss: 0.0059, Acc: 1102.0/1807 (0.6099)
Server - Avg_loss: 0.0050, Acc: 1164.0/1805 (0.6449)
Epoch:21 mean: 0.6080	0.5773	0.5494	0.5222	

Client 0	Epoch:  0		Loss: 0.67065340
Client 1	Epoch:  0		Loss: 0.69583893
Client 2	Epoch:  0		Loss: 0.49942368
Client 4	Epoch:  0		Loss: 0.64124244
Server - Avg_loss: 0.0058, Acc: 1053.0/1799 (0.5853)
Server - Avg_loss: 0.0052, Acc: 1062.0/1799 (0.5903)
Server - Avg_loss: 0.0056, Acc: 1098.0/1807 (0.

Server - Avg_loss: 0.0054, Acc: 1097.0/1807 (0.6071)
Server - Avg_loss: 0.0051, Acc: 1167.0/1805 (0.6465)
Epoch:40 mean: 0.6075	0.5867	0.5553	0.5297	

Client 0	Epoch:  0		Loss: 0.49596629
Client 1	Epoch:  0		Loss: 0.83899498
Client 2	Epoch:  0		Loss: 0.75763011
Client 4	Epoch:  0		Loss: 0.74538773
Server - Avg_loss: 0.0056, Acc: 1047.0/1799 (0.5820)
Server - Avg_loss: 0.0056, Acc: 1077.0/1799 (0.5987)
Server - Avg_loss: 0.0054, Acc: 1107.0/1807 (0.6126)
Server - Avg_loss: 0.0048, Acc: 1171.0/1805 (0.6488)
Epoch:41 mean: 0.6105	0.5816	0.5553	0.5292	

Client 0	Epoch:  0		Loss: 0.69008052
Client 1	Epoch:  0		Loss: 0.69018501
Client 2	Epoch:  0		Loss: 0.60851163
Client 4	Epoch:  0		Loss: 0.59834063
Server - Avg_loss: 0.0055, Acc: 1050.0/1799 (0.5837)
Server - Avg_loss: 0.0054, Acc: 1098.0/1799 (0.6103)
Server - Avg_loss: 0.0052, Acc: 1123.0/1807 (0.6215)
Server - Avg_loss: 0.0045, Acc: 1195.0/1805 (0.6620)
Epoch:42 mean: 0.6194	0.5836	0.5634	0.5391	

Client 0	Epoch:  0		Loss: 1.01595664
Cl

Client 0	Epoch:  0		Loss: 0.71296394
Client 1	Epoch:  0		Loss: 0.56991476
Client 2	Epoch:  0		Loss: 0.47281691
Client 4	Epoch:  0		Loss: 0.66326082
Server - Avg_loss: 0.0059, Acc: 1026.0/1799 (0.5703)
Server - Avg_loss: 0.0055, Acc: 1067.0/1799 (0.5931)
Server - Avg_loss: 0.0054, Acc: 1100.0/1807 (0.6087)
Server - Avg_loss: 0.0054, Acc: 1154.0/1805 (0.6393)
Epoch:61 mean: 0.6029	0.5718	0.5456	0.5195	

Client 0	Epoch:  0		Loss: 1.15929866
Client 1	Epoch:  0		Loss: 1.22308588
Client 2	Epoch:  0		Loss: 0.79986989
Client 4	Epoch:  0		Loss: 0.38863680
Server - Avg_loss: 0.0057, Acc: 1045.0/1799 (0.5809)
Server - Avg_loss: 0.0057, Acc: 1051.0/1799 (0.5842)
Server - Avg_loss: 0.0058, Acc: 1066.0/1807 (0.5899)
Server - Avg_loss: 0.0050, Acc: 1166.0/1805 (0.6460)
Epoch:62 mean: 0.6003	0.5698	0.5417	0.5179	

Client 0	Epoch:  0		Loss: 0.85276592
Client 1	Epoch:  0		Loss: 0.91765791
Client 2	Epoch:  0		Loss: 0.96967381
Client 4	Epoch:  0		Loss: 0.87083250
Server - Avg_loss: 0.0067, Acc: 1038.0/179

Client 4	Epoch:  0		Loss: 0.75892234
Server - Avg_loss: 0.0065, Acc: 1006.0/1799 (0.5592)
Server - Avg_loss: 0.0059, Acc: 1053.0/1799 (0.5853)
Server - Avg_loss: 0.0056, Acc: 1073.0/1807 (0.5938)
Server - Avg_loss: 0.0059, Acc: 1125.0/1805 (0.6233)
Epoch:81 mean: 0.5904	0.5730	0.5372	0.5082	

Client 0	Epoch:  0		Loss: 0.80984086
Client 1	Epoch:  0		Loss: 1.05215108
Client 2	Epoch:  0		Loss: 0.54846674
Client 4	Epoch:  0		Loss: 0.73121357
Server - Avg_loss: 0.0064, Acc: 1042.0/1799 (0.5792)
Server - Avg_loss: 0.0063, Acc: 1072.0/1799 (0.5959)
Server - Avg_loss: 0.0056, Acc: 1104.0/1807 (0.6110)
Server - Avg_loss: 0.0049, Acc: 1162.0/1805 (0.6438)
Epoch:82 mean: 0.6075	0.5697	0.5472	0.5207	

Client 0	Epoch:  0		Loss: 0.67195946
Client 1	Epoch:  0		Loss: 1.06300676
Client 2	Epoch:  0		Loss: 0.62199408
Client 4	Epoch:  0		Loss: 0.60590345
Server - Avg_loss: 0.0057, Acc: 1025.0/1799 (0.5698)
Server - Avg_loss: 0.0053, Acc: 1072.0/1799 (0.5959)
Server - Avg_loss: 0.0054, Acc: 1083.0/1807 (0.

In [1]:
# Scratch dict mapping keys to lists, used to check dict.get-with-default behaviour.
repo = dict([(1, [1])])

In [2]:
# dict.get returns the default list without inserting it, so this append is lost and repo
# is unchanged. Use repo.setdefault(2, []).append(3) to actually insert the key.
fallback = repo.get(2, [])
fallback.append(3)
repo

{1: [1]}

In [3]:
# Key 2 is absent because dict.get(2, []) does not insert its default. Show the resulting
# KeyError without leaving a cell that aborts Restart & Run All.
try:
    print(repo[2])
except KeyError as missing_key:
    print('KeyError:', missing_key)

KeyError: 2