In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# Experiment switches. They are read by the later cells of this notebook.

# one in ['ours', 'DSGD', 'SISA', 'Fed']
# Used to build the log file name.
# NOTE(review): the device-initialisation cell instantiates SISA_Device
# unconditionally, so this value does not select the framework class here.
Framework = 'SISA'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
# Selects the torchvision dataset to load and is part of the log file name.
DatasetName = 'FMNIST'

# one in ['resnet', 'mobilenet']
# Selects the model family, the training batch size and the iteration count.
ModelType = 'resnet'

# one in [1,0]. 1 - Heterogeneous;  0 - Homogeneous
# 1 -> every device gets ResNet8 / MobileNet_S,
# 0 -> every device gets ResNet50 / MobileNet_L.
Heterogeneous = 1

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration, RNG seeding and dataset loading.
# ---------------------------------------------------------------------------
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 5
num_channel = 1  # switched to 3 below when CIFAR10 is selected
ref_size = 10000
# Every device owns one contiguous slice of this many training samples; the
# slice is later split into a local train part and a local test part.
train_test_total_size = 50000 // device_num
test_ratio = 0.2

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/unlearn_{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG in play, not only the CUDA generator: DataLoader shuffling
# and CPU-side initialisation draw from the CPU generator, and `random` /
# NumPy are imported by this notebook as well.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)
torch.cuda.manual_seed_all(my_seed)

iter_num = 1000 if ModelType == 'resnet' else 600

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    num_channel = 3
elif DatasetName == 'MNIST':
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)
    iter_num = 600 if ModelType == 'resnet' else 300
elif DatasetName == 'FMNIST':
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)
else:
    # Fail early with a clear message instead of an AttributeError on None.
    raise ValueError(
        "Unknown DatasetName {!r}; expected one of "
        "['MNIST', 'FMNIST', 'CIFAR10']".format(DatasetName))

# Both splits use the same preprocessing: convert to tensor, then normalise
# with mean 0.5 / std 0.5.
normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set_o.transform = normalize_transform
test_set_o.transform = normalize_transform

# Reference set: the first `ref_size` samples of the official test split.
ref_set = Subset(test_set_o, range(0, int(ref_size)))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

In [3]:
# data manipulation
# For every device: take its contiguous slice of the training set, drop one
# class from it, split the remainder into local train / test parts and wrap
# both in DataLoaders. Afterwards create one device object per device id.
device_dict = {}
loader_dict = {}

# Header row of the per-device label statistics table.
print('\t'.join(['Device']
                + ['Class' + str(class_id) for class_id in range(num_classes)]
                + ['SUM']))

for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size

    # remove one class from each local dataset
    class_to_remove = device_id % num_classes
    # `targets` is a tensor for MNIST/FMNIST and a list for CIFAR10;
    # as_tensor handles both without the copy warning torch.tensor emits
    # when handed an existing tensor.
    local_targets = torch.as_tensor(train_set_o.targets[range_start:range_end])
    # Positions (relative to range_start) of the samples that are kept.
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]

    # split train&test; shift by range_start to get dataset-level indices
    train_test_border = int((1 - test_ratio) * len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border] + range_start)
    test_set = Subset(train_set_o, indices[train_test_border:] + range_start)

    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    # Evaluation does not need shuffling; a fixed order keeps it deterministic.
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=False, num_workers=0)
    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)

# initialize devices
# All devices are placed on GPU 0.
# NOTE(review): SISA_Device is used regardless of the `Framework` setting.
gpu_id = 0
for device_id in range(device_num):
    device_dict[device_id] = SISA_Device(device_id, gpu_id, num_classes, num_channel)

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	858	809	806	762	794	806	818	792	801	7246
D-1	800	0	778	793	791	817	823	794	779	826	7201
D-2	827	799	0	787	813	806	793	825	796	748	7194
D-3	821	768	764	0	797	803	790	795	831	827	7196
D-4	785	798	839	779	0	793	787	784	824	816	7205


In [4]:
# Pick the model classes for the selected scenario, log the scenario name,
# then give every device its own model instance and Adam optimizer.
if Heterogeneous == 1:
    # heterogeneous scenario
    scenario_name = 'heterogeneous'
    resnet_cls, mobilenet_cls = ResNet8, MobileNet_S
else:
    # homogeneous scenario
    scenario_name = 'homogeneous'
    resnet_cls, mobilenet_cls = ResNet50, MobileNet_L

write_log(log_path, scenario_name)

model_cls = resnet_cls if ModelType == 'resnet' else mobilenet_cls
for dev in device_dict.values():
    dev.main_model = model_cls(num_channel)
    dev.main_model.cuda(dev.gpu_id)
    dev.optimizer = optim.Adam(dev.main_model.parameters(), lr=0.01)

In [5]:
# assign neighbors
# Fully connected topology: each device lists every other device id.
all_device_ids = list(device_dict.keys())
for dev_id, dev in device_dict.items():
    dev.neighbor_list = [other for other in all_device_ids if other != dev_id]
    print('{}: {}'.format(dev_id, dev.neighbor_list))

0: [1, 2, 3, 4]
1: [0, 2, 3, 4]
2: [0, 1, 3, 4]
3: [0, 1, 2, 4]
4: [0, 1, 2, 3]


In [6]:
# train local main models
# Each device trains on its own train loader and is then validated on its
# own test loader; the per-device validation results are averaged at the end.
write_log(log_path, time.ctime())
metric = []
for dev_id, dev in device_dict.items():
    local_train, local_test = loader_dict[dev_id]
    dev.train_main_model(num_iter=iter_num, local_loader=local_train)
    score = dev.validate_main_model(test_loader=local_test)
    metric.append(score)
print('mean: {}'.format(sum(metric) / len(metric)))


Device: 0 main model training
Iter:   0		Loss: 1.47176433
Iter:   1		Loss: 1.14225256
Iter:   2		Loss: 0.56921333
Iter:   3		Loss: 0.71758020
Iter:   4		Loss: 0.46546766
Iter:   5		Loss: 0.46458307
Iter:   6		Loss: 0.44610167
Iter:   7		Loss: 0.45232767
Iter:   8		Loss: 0.56695616
Iter:   9		Loss: 0.50574327
Iter:  10		Loss: 0.42598349
Iter:  11		Loss: 0.43091348
Iter:  12		Loss: 0.29077291
Iter:  13		Loss: 0.31339234
Iter:  14		Loss: 0.27052489
Iter:  15		Loss: 0.41307396
Iter:  16		Loss: 0.28768316
Iter:  17		Loss: 0.21994124
Iter:  18		Loss: 0.26497245
Iter:  19		Loss: 0.22555026
Iter:  20		Loss: 0.33637831
Iter:  21		Loss: 0.29247141
Iter:  22		Loss: 0.21563183
Iter:  23		Loss: 0.27726409
Iter:  24		Loss: 0.19948891
Iter:  25		Loss: 0.27536002
Iter:  26		Loss: 0.25101563
Iter:  27		Loss: 0.27837706
Iter:  28		Loss: 0.27230564
Iter:  29		Loss: 0.19685011
Iter:  30		Loss: 0.24398437
Iter:  31		Loss: 0.22734463
Iter:  32		Loss: 0.10918438
Iter:  33		Loss: 0.15319201
Iter:  34		Loss: 

Iter: 292		Loss: 0.00213352
Iter: 293		Loss: 0.00166732
Iter: 294		Loss: 0.02361766
Iter: 295		Loss: 0.03753254
Iter: 296		Loss: 0.00350319
Iter: 297		Loss: 0.00386324
Iter: 298		Loss: 0.00726851
Iter: 299		Loss: 0.02092793
Iter: 300		Loss: 0.01029490
Iter: 301		Loss: 0.08440681
Iter: 302		Loss: 0.03037160
Iter: 303		Loss: 0.14303707
Iter: 304		Loss: 0.05797297
Iter: 305		Loss: 0.05989411
Iter: 306		Loss: 0.07235216
Iter: 307		Loss: 0.06969277
Iter: 308		Loss: 0.07004535
Iter: 309		Loss: 0.03707095
Iter: 310		Loss: 0.00615101
Iter: 311		Loss: 0.00385848
Iter: 312		Loss: 0.01667565
Iter: 313		Loss: 0.02393785
Iter: 314		Loss: 0.00175611
Iter: 315		Loss: 0.00601472
Iter: 316		Loss: 0.01911912
Iter: 317		Loss: 0.01617287
Iter: 318		Loss: 0.00161140
Iter: 319		Loss: 0.00706143
Iter: 320		Loss: 0.00215832
Iter: 321		Loss: 0.00144911
Iter: 322		Loss: 0.01225213
Iter: 323		Loss: 0.00240903
Iter: 324		Loss: 0.00097274
Iter: 325		Loss: 0.00162152
Iter: 326		Loss: 0.01035473
Iter: 327		Loss: 0.0

Iter: 585		Loss: 0.00033605
Iter: 586		Loss: 0.00537730
Iter: 587		Loss: 0.00297417
Iter: 588		Loss: 0.00022863
Iter: 589		Loss: 0.00113204
Iter: 590		Loss: 0.00425747
Iter: 591		Loss: 0.00271529
Iter: 592		Loss: 0.00515367
Iter: 593		Loss: 0.00038139
Iter: 594		Loss: 0.00140241
Iter: 595		Loss: 0.00045633
Iter: 596		Loss: 0.00007648
Iter: 597		Loss: 0.00136215
Iter: 598		Loss: 0.00238861
Iter: 599		Loss: 0.00014562
Iter: 600		Loss: 0.00066759
Iter: 601		Loss: 0.00528494
Iter: 602		Loss: 0.00125471
Iter: 603		Loss: 0.00119637
Iter: 604		Loss: 0.00365798
Iter: 605		Loss: 0.00083599
Iter: 606		Loss: 0.00240719
Iter: 607		Loss: 0.00257260
Iter: 608		Loss: 0.00039255
Iter: 609		Loss: 0.00022893
Iter: 610		Loss: 0.00191916
Iter: 611		Loss: 0.00744036
Iter: 612		Loss: 0.03344749
Iter: 613		Loss: 0.22306758
Iter: 614		Loss: 0.19455098
Iter: 615		Loss: 0.09078961
Iter: 616		Loss: 0.16100575
Iter: 617		Loss: 0.13191710
Iter: 618		Loss: 0.00351174
Iter: 619		Loss: 0.00547253
Iter: 620		Loss: 0.0

Iter: 878		Loss: 0.00073825
Iter: 879		Loss: 0.01619072
Iter: 880		Loss: 0.00141537
Iter: 881		Loss: 0.00125394
Iter: 882		Loss: 0.00278633
Iter: 883		Loss: 0.01502746
Iter: 884		Loss: 0.05368323
Iter: 885		Loss: 0.01676552
Iter: 886		Loss: 0.09366321
Iter: 887		Loss: 0.00536117
Iter: 888		Loss: 0.11565167
Iter: 889		Loss: 0.05112786
Iter: 890		Loss: 0.07181067
Iter: 891		Loss: 0.01209195
Iter: 892		Loss: 0.02960336
Iter: 893		Loss: 0.20770667
Iter: 894		Loss: 0.00889447
Iter: 895		Loss: 0.00879937
Iter: 896		Loss: 0.00636554
Iter: 897		Loss: 0.00150157
Iter: 898		Loss: 0.00275279
Iter: 899		Loss: 0.00998013
Iter: 900		Loss: 0.00384719
Iter: 901		Loss: 0.00451698
Iter: 902		Loss: 0.00097827
Iter: 903		Loss: 0.00130807
Iter: 904		Loss: 0.00723534
Iter: 905		Loss: 0.00179640
Iter: 906		Loss: 0.00085560
Iter: 907		Loss: 0.00021411
Iter: 908		Loss: 0.00198988
Iter: 909		Loss: 0.00050924
Iter: 910		Loss: 0.00157208
Iter: 911		Loss: 0.00220614
Iter: 912		Loss: 0.00166655
Iter: 913		Loss: 0.0

Iter: 168		Loss: 0.18557474
Iter: 169		Loss: 0.31481123
Iter: 170		Loss: 0.18369034
Iter: 171		Loss: 0.22215272
Iter: 172		Loss: 0.10911578
Iter: 173		Loss: 0.07958680
Iter: 174		Loss: 0.05630978
Iter: 175		Loss: 0.12642038
Iter: 176		Loss: 0.08538651
Iter: 177		Loss: 0.03330253
Iter: 178		Loss: 0.30914477
Iter: 179		Loss: 0.23281530
Iter: 180		Loss: 0.02130980
Iter: 181		Loss: 0.05172744
Iter: 182		Loss: 0.15315282
Iter: 183		Loss: 0.32160214
Iter: 184		Loss: 0.25128734
Iter: 185		Loss: 0.21941872
Iter: 186		Loss: 0.26146826
Iter: 187		Loss: 0.05725821
Iter: 188		Loss: 0.19503267
Iter: 189		Loss: 0.17154337
Iter: 190		Loss: 0.08108792
Iter: 191		Loss: 0.06366056
Iter: 192		Loss: 0.11762405
Iter: 193		Loss: 0.03562138
Iter: 194		Loss: 0.02313934
Iter: 195		Loss: 0.02024085
Iter: 196		Loss: 0.14311263
Iter: 197		Loss: 0.01505277
Iter: 198		Loss: 0.08316308
Iter: 199		Loss: 0.26062351
Iter: 200		Loss: 0.24981406
Iter: 201		Loss: 0.12698008
Iter: 202		Loss: 0.07885158
Iter: 203		Loss: 0.1

Iter: 461		Loss: 0.03327189
Iter: 462		Loss: 0.48206636
Iter: 463		Loss: 0.91534942
Iter: 464		Loss: 0.15585850
Iter: 465		Loss: 0.02374568
Iter: 466		Loss: 0.07365350
Iter: 467		Loss: 0.04912788
Iter: 468		Loss: 0.05480879
Iter: 469		Loss: 0.02076114
Iter: 470		Loss: 0.05230185
Iter: 471		Loss: 0.11430024
Iter: 472		Loss: 0.30482769
Iter: 473		Loss: 0.05172682
Iter: 474		Loss: 0.09978398
Iter: 475		Loss: 0.01335185
Iter: 476		Loss: 0.03906552
Iter: 477		Loss: 0.04862627
Iter: 478		Loss: 0.01084093
Iter: 479		Loss: 0.00642076
Iter: 480		Loss: 0.01160160
Iter: 481		Loss: 0.00201731
Iter: 482		Loss: 0.17187697
Iter: 483		Loss: 0.07586449
Iter: 484		Loss: 0.18977226
Iter: 485		Loss: 0.09131860
Iter: 486		Loss: 0.02993853
Iter: 487		Loss: 0.26592910
Iter: 488		Loss: 0.02859014
Iter: 489		Loss: 0.14190382
Iter: 490		Loss: 0.03731493
Iter: 491		Loss: 0.20449606
Iter: 492		Loss: 0.09984627
Iter: 493		Loss: 0.08021961
Iter: 494		Loss: 0.00731074
Iter: 495		Loss: 0.12881726
Iter: 496		Loss: 0.0

Iter: 754		Loss: 0.01251931
Iter: 755		Loss: 0.00299535
Iter: 756		Loss: 0.28983092
Iter: 757		Loss: 0.28884861
Iter: 758		Loss: 0.43497905
Iter: 759		Loss: 0.09744004
Iter: 760		Loss: 0.02749221
Iter: 761		Loss: 0.04212619
Iter: 762		Loss: 0.25043416
Iter: 763		Loss: 0.18141238
Iter: 764		Loss: 0.00815547
Iter: 765		Loss: 0.26798263
Iter: 766		Loss: 0.01265646
Iter: 767		Loss: 0.12094461
Iter: 768		Loss: 0.00306585
Iter: 769		Loss: 0.01531945
Iter: 770		Loss: 0.01623521
Iter: 771		Loss: 0.00321441
Iter: 772		Loss: 0.02476126
Iter: 773		Loss: 0.02510468
Iter: 774		Loss: 0.00409752
Iter: 775		Loss: 0.00748533
Iter: 776		Loss: 0.01124766
Iter: 777		Loss: 0.05627738
Iter: 778		Loss: 0.03503380
Iter: 779		Loss: 0.02569262
Iter: 780		Loss: 0.05683950
Iter: 781		Loss: 0.10390797
Iter: 782		Loss: 0.00104217
Iter: 783		Loss: 0.00836693
Iter: 784		Loss: 0.00246435
Iter: 785		Loss: 0.00734211
Iter: 786		Loss: 0.01573082
Iter: 787		Loss: 0.01223242
Iter: 788		Loss: 0.06134598
Iter: 789		Loss: 0.0

Iter:  44		Loss: 0.25875157
Iter:  45		Loss: 0.24330498
Iter:  46		Loss: 0.24059342
Iter:  47		Loss: 0.18370040
Iter:  48		Loss: 0.04952816
Iter:  49		Loss: 0.14015523
Iter:  50		Loss: 0.25539219
Iter:  51		Loss: 0.33640954
Iter:  52		Loss: 0.22037043
Iter:  53		Loss: 0.28766781
Iter:  54		Loss: 0.19485083
Iter:  55		Loss: 0.10007441
Iter:  56		Loss: 0.28670958
Iter:  57		Loss: 0.14513235
Iter:  58		Loss: 0.41570088
Iter:  59		Loss: 0.14761031
Iter:  60		Loss: 0.12650715
Iter:  61		Loss: 0.30028665
Iter:  62		Loss: 0.04991804
Iter:  63		Loss: 0.10998654
Iter:  64		Loss: 0.18472759
Iter:  65		Loss: 0.06679677
Iter:  66		Loss: 0.02786852
Iter:  67		Loss: 0.10754732
Iter:  68		Loss: 0.17557806
Iter:  69		Loss: 0.17898330
Iter:  70		Loss: 0.19409254
Iter:  71		Loss: 0.11278038
Iter:  72		Loss: 0.06559278
Iter:  73		Loss: 0.26205483
Iter:  74		Loss: 0.16248560
Iter:  75		Loss: 0.07040299
Iter:  76		Loss: 0.11098348
Iter:  77		Loss: 0.20241322
Iter:  78		Loss: 0.01943192
Iter:  79		Loss: 0.1

Iter: 337		Loss: 0.04961935
Iter: 338		Loss: 0.04270820
Iter: 339		Loss: 0.00683610
Iter: 340		Loss: 0.00849083
Iter: 341		Loss: 0.03064386
Iter: 342		Loss: 0.00680215
Iter: 343		Loss: 0.00517969
Iter: 344		Loss: 0.37419865
Iter: 345		Loss: 0.01058947
Iter: 346		Loss: 0.00224637
Iter: 347		Loss: 0.00184507
Iter: 348		Loss: 0.08820352
Iter: 349		Loss: 0.05673829
Iter: 350		Loss: 0.01322597
Iter: 351		Loss: 0.08842543
Iter: 352		Loss: 0.00947629
Iter: 353		Loss: 0.10098337
Iter: 354		Loss: 0.01882740
Iter: 355		Loss: 0.01242303
Iter: 356		Loss: 0.00608447
Iter: 357		Loss: 0.00440234
Iter: 358		Loss: 0.00109889
Iter: 359		Loss: 0.00738806
Iter: 360		Loss: 0.00050612
Iter: 361		Loss: 0.01262087
Iter: 362		Loss: 0.00225374
Iter: 363		Loss: 0.00672541
Iter: 364		Loss: 0.05160413
Iter: 365		Loss: 0.08043523
Iter: 366		Loss: 0.00637401
Iter: 367		Loss: 0.00109691
Iter: 368		Loss: 0.03467782
Iter: 369		Loss: 0.01172428
Iter: 370		Loss: 0.00546468
Iter: 371		Loss: 0.00004150
Iter: 372		Loss: 0.1

Iter: 630		Loss: 0.02350063
Iter: 631		Loss: 0.01602762
Iter: 632		Loss: 0.00644197
Iter: 633		Loss: 0.61481184
Iter: 634		Loss: 0.01398565
Iter: 635		Loss: 0.31294060
Iter: 636		Loss: 0.04431345
Iter: 637		Loss: 0.03124894
Iter: 638		Loss: 0.00675322
Iter: 639		Loss: 0.01961413
Iter: 640		Loss: 0.09745451
Iter: 641		Loss: 0.10389252
Iter: 642		Loss: 0.00409782
Iter: 643		Loss: 0.02581815
Iter: 644		Loss: 0.00180807
Iter: 645		Loss: 0.03962861
Iter: 646		Loss: 0.02889544
Iter: 647		Loss: 0.01285760
Iter: 648		Loss: 0.00211714
Iter: 649		Loss: 0.64510334
Iter: 650		Loss: 0.54861331
Iter: 651		Loss: 0.17429863
Iter: 652		Loss: 0.03589146
Iter: 653		Loss: 0.27241343
Iter: 654		Loss: 0.01552705
Iter: 655		Loss: 0.04244795
Iter: 656		Loss: 0.12042528
Iter: 657		Loss: 0.06220978
Iter: 658		Loss: 0.02963937
Iter: 659		Loss: 0.22424854
Iter: 660		Loss: 0.00282007
Iter: 661		Loss: 0.32027122
Iter: 662		Loss: 0.19911966
Iter: 663		Loss: 0.05631290
Iter: 664		Loss: 0.00095112
Iter: 665		Loss: 0.0

Iter: 923		Loss: 0.00033109
Iter: 924		Loss: 0.00274808
Iter: 925		Loss: 0.03224608
Iter: 926		Loss: 0.00727698
Iter: 927		Loss: 0.02972252
Iter: 928		Loss: 0.00344654
Iter: 929		Loss: 0.00614943
Iter: 930		Loss: 0.00490532
Iter: 931		Loss: 0.02471947
Iter: 932		Loss: 0.01519725
Iter: 933		Loss: 0.47388536
Iter: 934		Loss: 0.04135223
Iter: 935		Loss: 0.08877695
Iter: 936		Loss: 0.12434855
Iter: 937		Loss: 0.06188313
Iter: 938		Loss: 0.00189467
Iter: 939		Loss: 0.09401415
Iter: 940		Loss: 0.00029930
Iter: 941		Loss: 0.01149148
Iter: 942		Loss: 0.00265220
Iter: 943		Loss: 0.01364902
Iter: 944		Loss: 0.09043623
Iter: 945		Loss: 0.01192881
Iter: 946		Loss: 0.01364487
Iter: 947		Loss: 0.08431395
Iter: 948		Loss: 0.02371362
Iter: 949		Loss: 0.00016656
Iter: 950		Loss: 0.04177447
Iter: 951		Loss: 0.08201405
Iter: 952		Loss: 0.00854212
Iter: 953		Loss: 0.12152144
Iter: 954		Loss: 0.08796943
Iter: 955		Loss: 0.04700431
Iter: 956		Loss: 0.00846680
Iter: 957		Loss: 0.22276534
Iter: 958		Loss: 0.0

Iter: 213		Loss: 0.08935793
Iter: 214		Loss: 0.15340841
Iter: 215		Loss: 0.11119853
Iter: 216		Loss: 0.07247652
Iter: 217		Loss: 0.00973786
Iter: 218		Loss: 0.00426042
Iter: 219		Loss: 0.12148707
Iter: 220		Loss: 0.19399115
Iter: 221		Loss: 0.26321974
Iter: 222		Loss: 0.02749675
Iter: 223		Loss: 0.02644573
Iter: 224		Loss: 0.04318783
Iter: 225		Loss: 0.08105167
Iter: 226		Loss: 0.10174751
Iter: 227		Loss: 0.02458201
Iter: 228		Loss: 0.02359620
Iter: 229		Loss: 0.05397614
Iter: 230		Loss: 0.32627425
Iter: 231		Loss: 0.18558812
Iter: 232		Loss: 0.10501384
Iter: 233		Loss: 0.02737770
Iter: 234		Loss: 0.38849092
Iter: 235		Loss: 0.73650122
Iter: 236		Loss: 0.10521078
Iter: 237		Loss: 0.14494860
Iter: 238		Loss: 0.01683632
Iter: 239		Loss: 0.09658729
Iter: 240		Loss: 0.04967578
Iter: 241		Loss: 0.02511924
Iter: 242		Loss: 0.07153995
Iter: 243		Loss: 0.00560979
Iter: 244		Loss: 0.03817090
Iter: 245		Loss: 0.01310658
Iter: 246		Loss: 0.34548122
Iter: 247		Loss: 0.01193628
Iter: 248		Loss: 0.0

Iter: 506		Loss: 0.15491152
Iter: 507		Loss: 0.01330042
Iter: 508		Loss: 0.02286610
Iter: 509		Loss: 0.06474965
Iter: 510		Loss: 0.08681647
Iter: 511		Loss: 0.03159748
Iter: 512		Loss: 0.00387205
Iter: 513		Loss: 0.01060218
Iter: 514		Loss: 0.01514061
Iter: 515		Loss: 0.08233105
Iter: 516		Loss: 0.03561545
Iter: 517		Loss: 0.21311462
Iter: 518		Loss: 0.16961908
Iter: 519		Loss: 0.46482676
Iter: 520		Loss: 0.01152575
Iter: 521		Loss: 0.00209322
Iter: 522		Loss: 0.02141902
Iter: 523		Loss: 0.08839136
Iter: 524		Loss: 0.11400525
Iter: 525		Loss: 0.03008642
Iter: 526		Loss: 0.06953547
Iter: 527		Loss: 0.03394132
Iter: 528		Loss: 0.02331266
Iter: 529		Loss: 0.02313275
Iter: 530		Loss: 0.00368625
Iter: 531		Loss: 0.16161874
Iter: 532		Loss: 0.00500783
Iter: 533		Loss: 0.00245031
Iter: 534		Loss: 0.00654114
Iter: 535		Loss: 0.01813857
Iter: 536		Loss: 0.01901929
Iter: 537		Loss: 0.00229153
Iter: 538		Loss: 0.05752493
Iter: 539		Loss: 0.07703783
Iter: 540		Loss: 0.00858706
Iter: 541		Loss: 0.0

Iter: 799		Loss: 0.28270334
Iter: 800		Loss: 0.01780733
Iter: 801		Loss: 0.00024531
Iter: 802		Loss: 0.08926450
Iter: 803		Loss: 0.01163832
Iter: 804		Loss: 0.01531539
Iter: 805		Loss: 0.04197205
Iter: 806		Loss: 0.00174270
Iter: 807		Loss: 0.04679480
Iter: 808		Loss: 0.10314353
Iter: 809		Loss: 0.00739004
Iter: 810		Loss: 0.00398294
Iter: 811		Loss: 0.00724046
Iter: 812		Loss: 0.00826332
Iter: 813		Loss: 0.01423145
Iter: 814		Loss: 0.07220329
Iter: 815		Loss: 0.05368099
Iter: 816		Loss: 0.35537794
Iter: 817		Loss: 0.11907478
Iter: 818		Loss: 0.00509053
Iter: 819		Loss: 0.02366089
Iter: 820		Loss: 0.17498459
Iter: 821		Loss: 0.56533659
Iter: 822		Loss: 0.01690430
Iter: 823		Loss: 0.18261883
Iter: 824		Loss: 0.01520689
Iter: 825		Loss: 0.00923001
Iter: 826		Loss: 0.01228157
Iter: 827		Loss: 0.00201397
Iter: 828		Loss: 0.01363439
Iter: 829		Loss: 0.03677495
Iter: 830		Loss: 0.00530449
Iter: 831		Loss: 0.00153565
Iter: 832		Loss: 0.09261917
Iter: 833		Loss: 0.13158897
Iter: 834		Loss: 0.2

Iter:  89		Loss: 0.08507440
Iter:  90		Loss: 0.12376337
Iter:  91		Loss: 0.03594841
Iter:  92		Loss: 0.02963701
Iter:  93		Loss: 0.04955677
Iter:  94		Loss: 0.04652553
Iter:  95		Loss: 0.09341417
Iter:  96		Loss: 0.09927426
Iter:  97		Loss: 0.02431631
Iter:  98		Loss: 0.22931264
Iter:  99		Loss: 0.34181604
Iter: 100		Loss: 0.20464021
Iter: 101		Loss: 0.02346400
Iter: 102		Loss: 0.04380174
Iter: 103		Loss: 0.06074838
Iter: 104		Loss: 0.07115678
Iter: 105		Loss: 0.03405620
Iter: 106		Loss: 0.15113075
Iter: 107		Loss: 0.43117160
Iter: 108		Loss: 0.26734537
Iter: 109		Loss: 0.05915751
Iter: 110		Loss: 0.03386147
Iter: 111		Loss: 0.03646589
Iter: 112		Loss: 0.11030866
Iter: 113		Loss: 0.03369037
Iter: 114		Loss: 0.08950436
Iter: 115		Loss: 0.13661040
Iter: 116		Loss: 0.18990819
Iter: 117		Loss: 0.18376108
Iter: 118		Loss: 0.09185717
Iter: 119		Loss: 0.16065629
Iter: 120		Loss: 0.06404770
Iter: 121		Loss: 0.11341397
Iter: 122		Loss: 0.07551720
Iter: 123		Loss: 0.03787322
Iter: 124		Loss: 0.0

Iter: 382		Loss: 0.62590092
Iter: 383		Loss: 0.29621267
Iter: 384		Loss: 0.02004344
Iter: 385		Loss: 0.00970666
Iter: 386		Loss: 0.00463775
Iter: 387		Loss: 0.00380228
Iter: 388		Loss: 0.00464958
Iter: 389		Loss: 0.03224893
Iter: 390		Loss: 0.02624192
Iter: 391		Loss: 0.05866015
Iter: 392		Loss: 0.02777232
Iter: 393		Loss: 0.00845616
Iter: 394		Loss: 0.00719907
Iter: 395		Loss: 0.00845369
Iter: 396		Loss: 0.00143865
Iter: 397		Loss: 0.01924340
Iter: 398		Loss: 0.00047053
Iter: 399		Loss: 0.03832282
Iter: 400		Loss: 0.01827823
Iter: 401		Loss: 0.17267327
Iter: 402		Loss: 0.07822470
Iter: 403		Loss: 0.00575845
Iter: 404		Loss: 0.04638180
Iter: 405		Loss: 0.00064262
Iter: 406		Loss: 0.04660369
Iter: 407		Loss: 0.00986796
Iter: 408		Loss: 0.14522576
Iter: 409		Loss: 0.00684403
Iter: 410		Loss: 0.00731216
Iter: 411		Loss: 0.00992843
Iter: 412		Loss: 0.00842374
Iter: 413		Loss: 0.02792916
Iter: 414		Loss: 0.01505645
Iter: 415		Loss: 0.00609288
Iter: 416		Loss: 0.00084465
Iter: 417		Loss: 0.0

Iter: 675		Loss: 0.00529066
Iter: 676		Loss: 0.01592434
Iter: 677		Loss: 0.00915654
Iter: 678		Loss: 0.01958857
Iter: 679		Loss: 0.09050428
Iter: 680		Loss: 0.08449885
Iter: 681		Loss: 0.36984468
Iter: 682		Loss: 0.04948881
Iter: 683		Loss: 0.00827477
Iter: 684		Loss: 0.00256573
Iter: 685		Loss: 0.01007948
Iter: 686		Loss: 0.00162114
Iter: 687		Loss: 0.00488595
Iter: 688		Loss: 0.00157196
Iter: 689		Loss: 0.00349436
Iter: 690		Loss: 0.00017032
Iter: 691		Loss: 0.00169694
Iter: 692		Loss: 0.00166979
Iter: 693		Loss: 0.00556238
Iter: 694		Loss: 0.00079094
Iter: 695		Loss: 0.00419313
Iter: 696		Loss: 0.00030406
Iter: 697		Loss: 0.00023519
Iter: 698		Loss: 0.00764434
Iter: 699		Loss: 0.01199200
Iter: 700		Loss: 0.01172965
Iter: 701		Loss: 0.00061571
Iter: 702		Loss: 0.11733308
Iter: 703		Loss: 0.00668137
Iter: 704		Loss: 0.00351684
Iter: 705		Loss: 0.00732637
Iter: 706		Loss: 0.04619568
Iter: 707		Loss: 0.04398888
Iter: 708		Loss: 0.11069316
Iter: 709		Loss: 0.06307622
Iter: 710		Loss: 0.0

Iter: 968		Loss: 0.00735929
Iter: 969		Loss: 0.00569549
Iter: 970		Loss: 0.00109700
Iter: 971		Loss: 0.01727061
Iter: 972		Loss: 0.01038295
Iter: 973		Loss: 0.00266318
Iter: 974		Loss: 0.00047623
Iter: 975		Loss: 0.00152507
Iter: 976		Loss: 0.00033162
Iter: 977		Loss: 0.01184409
Iter: 978		Loss: 0.00832379
Iter: 979		Loss: 0.00219966
Iter: 980		Loss: 0.00224609
Iter: 981		Loss: 0.00003383
Iter: 982		Loss: 0.00221088
Iter: 983		Loss: 0.00086306
Iter: 984		Loss: 0.00109378
Iter: 985		Loss: 0.00116549
Iter: 986		Loss: 0.04783385
Iter: 987		Loss: 0.00339830
Iter: 988		Loss: 0.00113259
Iter: 989		Loss: 0.05575734
Iter: 990		Loss: 0.26412231
Iter: 991		Loss: 0.00259516
Iter: 992		Loss: 0.00140822
Iter: 993		Loss: 0.00822617
Iter: 994		Loss: 0.02358696
Iter: 995		Loss: 0.00717285
Iter: 996		Loss: 0.04617677
Iter: 997		Loss: 0.00356740
Iter: 998		Loss: 0.01622532
Iter: 999		Loss: 0.15640442
Device: 4 Val_main - Avg_loss: 1.0657, Acc: 1594.0/1802 (0.8846)
mean: 0.8534051179885864


In [7]:
# assign neighbors
# Restore the full neighbor list: every device sees all other devices.
all_device_ids = list(device_dict.keys())
for dev_id, dev in device_dict.items():
    dev.neighbor_list = [other for other in all_device_ids if other != dev_id]

for dev_id, dev in device_dict.items():
    print('D_{}: {}'.format(dev_id, dev.neighbor_list))

D_0: [1, 2, 3, 4]
D_1: [0, 2, 3, 4]
D_2: [0, 1, 3, 4]
D_3: [0, 1, 2, 4]
D_4: [0, 1, 2, 3]


In [8]:
# overall performance
# Validate the ensemble on every device's local test loader and log the
# column-wise mean of the first four returned metrics.
write_log(log_path, time.ctime() + ' SMLLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
metric = [
    dev.validate_ensamble(test_loader=loader_dict[dev_id][1], device_dict=device_dict)
    for dev_id, dev in device_dict.items()
]
metric_arr = np.array(metric)
avg_metrics = np.mean(metric_arr, axis=0)  # one mean per metric column
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
    avg_metrics[0], avg_metrics[1], avg_metrics[2], avg_metrics[3])
print('mean: ' + log_txt)
write_log(log_path, log_txt)

SISA: Client 0 Test -  Accuracy: 1576.0/1812 (0.8698)
SISA: Client 1 Test -  Accuracy: 1483.0/1801 (0.8234)
SISA: Client 2 Test -  Accuracy: 1558.0/1799 (0.8660)
SISA: Client 3 Test -  Accuracy: 1527.0/1799 (0.8488)
SISA: Client 4 Test -  Accuracy: 1588.0/1802 (0.8812)
mean: 0.8579	0.8354	0.7993	0.8027	


In [9]:
# assign neighbors
# Device 4 is taken out of the topology: it gets no neighbors and no other
# device lists it. The removed device is excluded by id rather than by
# position (`[:-1]`), so the result does not depend on device 4 being the
# last key in device_dict.
removed_ids = (4,)
for dev_id, dev in device_dict.items():
    if dev_id in removed_ids:
        dev.neighbor_list = []
    else:
        dev.neighbor_list = [other for other in device_dict.keys()
                             if other != dev_id and other not in removed_ids]

for dev_id, dev in device_dict.items():
    print('D_{}: {}'.format(dev_id, dev.neighbor_list))

D_0: [1, 2, 3]
D_1: [0, 2, 3]
D_2: [0, 1, 3]
D_3: [0, 1, 2]
D_4: []


In [10]:
# overall performance
# Same evaluation as before, but device 4 is skipped.
write_log(log_path, time.ctime() + ' SMLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
metric = [
    dev.validate_ensamble(test_loader=loader_dict[dev_id][1], device_dict=device_dict)
    for dev_id, dev in device_dict.items()
    if dev_id != 4
]
metric_arr = np.array(metric)
avg_metrics = np.mean(metric_arr, axis=0)  # one mean per metric column
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
    avg_metrics[0], avg_metrics[1], avg_metrics[2], avg_metrics[3])
print('mean: ' + log_txt)
write_log(log_path, log_txt)

SISA: Client 0 Test -  Accuracy: 1585.0/1812 (0.8747)
SISA: Client 1 Test -  Accuracy: 1472.0/1801 (0.8173)
SISA: Client 2 Test -  Accuracy: 1549.0/1799 (0.8610)
SISA: Client 3 Test -  Accuracy: 1510.0/1799 (0.8394)
mean: 0.8481	0.8375	0.8018	0.8007	


In [11]:
# assign neighbors
# Devices 3 and 4 are taken out of the topology: they get empty neighbor
# lists and are removed from everyone else's list.
removed_ids = (3, 4)
for dev_id, dev in device_dict.items():
    if dev_id in removed_ids:
        dev.neighbor_list = []
    else:
        dev.neighbor_list = [other for other in device_dict.keys()
                             if other != dev_id and other not in removed_ids]

for dev_id, dev in device_dict.items():
    print('D_{}: {}'.format(dev_id, dev.neighbor_list))

D_0: [1, 2]
D_1: [0, 2]
D_2: [0, 1]
D_3: []
D_4: []


In [12]:
# overall performance
# Same evaluation as before, but devices 3 and 4 are skipped.
write_log(log_path, time.ctime() + ' SML')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
metric = [
    dev.validate_ensamble(test_loader=loader_dict[dev_id][1], device_dict=device_dict)
    for dev_id, dev in device_dict.items()
    if dev_id not in [3, 4]
]
metric_arr = np.array(metric)
avg_metrics = np.mean(metric_arr, axis=0)  # one mean per metric column
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
    avg_metrics[0], avg_metrics[1], avg_metrics[2], avg_metrics[3])
print('mean: ' + log_txt)
write_log(log_path, log_txt)

SISA: Client 0 Test -  Accuracy: 1542.0/1812 (0.8510)
SISA: Client 1 Test -  Accuracy: 1427.0/1801 (0.7923)
SISA: Client 2 Test -  Accuracy: 1520.0/1799 (0.8449)
mean: 0.8294	0.8411	0.7939	0.7875	


In [13]:
# assign neighbors
# Device 3 is taken out of the topology: it gets an empty neighbor list and
# is removed from everyone else's list.
removed_ids = (3,)
for dev_id, dev in device_dict.items():
    if dev_id in removed_ids:
        dev.neighbor_list = []
    else:
        dev.neighbor_list = [other for other in device_dict.keys()
                             if other != dev_id and other not in removed_ids]

for dev_id, dev in device_dict.items():
    print('D_{}: {}'.format(dev_id, dev.neighbor_list))

D_0: [1, 2, 4]
D_1: [0, 2, 4]
D_2: [0, 1, 4]
D_3: []
D_4: [0, 1, 2]


In [14]:
# overall performance
# Same evaluation as before, but device 3 is skipped.
write_log(log_path, time.ctime() + ' SMLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
metric = [
    dev.validate_ensamble(test_loader=loader_dict[dev_id][1], device_dict=device_dict)
    for dev_id, dev in device_dict.items()
    if dev_id not in [3]
]
metric_arr = np.array(metric)
avg_metrics = np.mean(metric_arr, axis=0)  # one mean per metric column
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
    avg_metrics[0], avg_metrics[1], avg_metrics[2], avg_metrics[3])
print('mean: ' + log_txt)
write_log(log_path, log_txt)

SISA: Client 0 Test -  Accuracy: 1533.0/1812 (0.8460)
SISA: Client 1 Test -  Accuracy: 1437.0/1801 (0.7979)
SISA: Client 2 Test -  Accuracy: 1513.0/1799 (0.8410)
SISA: Client 4 Test -  Accuracy: 1563.0/1802 (0.8674)
mean: 0.8381	0.8389	0.7954	0.7937	
