In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# one in ['ours', 'DSGD', 'SISA', 'Fed']
Framework = 'SISA'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'MNIST'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1, 0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 1

# Fail fast on typos: later cells branch on these values with plain
# if/else tests (e.g. any ModelType other than 'resnet' selects MobileNet),
# so an unexpected value would otherwise be accepted silently.
assert Framework in ('ours', 'DSGD', 'SISA', 'Fed'), 'unknown Framework: {!r}'.format(Framework)
assert DatasetName in ('MNIST', 'FMNIST', 'CIFAR10'), 'unknown DatasetName: {!r}'.format(DatasetName)
assert ModelType in ('resnet', 'mobilenet'), 'unknown ModelType: {!r}'.format(ModelType)
assert Heterogeneous in (0, 1), 'Heterogeneous must be 0 or 1, got {!r}'.format(Heterogeneous)

In [2]:
# Experiment-wide defaults (MNIST / FMNIST); the CIFAR10 branch below
# overrides the ones that differ for that dataset.
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 6
num_channel = 1
ref_size = 10000
train_test_total_size = int(60000/device_num)  # equal per-device share of the training set
test_ratio = 0.2
CIFAR10_segmentation = 0

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG the pipeline touches so DataLoader shuffling and model
# initialisation are reproducible. torch.manual_seed seeds the CPU generator
# and all CUDA devices.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)
iter_num = 100

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    device_num = 5
    num_channel = 3
    CIFAR10_segmentation = 1
    train_test_total_size = int(50000/device_num)
    iter_num = 1000 if ModelType == 'resnet' else 900

elif DatasetName == 'MNIST':
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)
    iter_num = 1000 if ModelType == 'resnet' else 900

elif DatasetName == 'FMNIST':
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)
    # Must set iter_num (the training loop reads it); otherwise FMNIST would
    # silently run with the default budget above.
    iter_num = 1000 if ModelType == 'resnet' else 900

else:
    raise ValueError("Unsupported DatasetName: {!r}".format(DatasetName))

# One stateless transform shared by train and test; the mean/std tuples get
# one entry per input channel.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * num_channel, (0.5,) * num_channel),
])
train_set_o.transform = data_transform
test_set_o.transform = data_transform

# Reference set: the first `ref_size` samples of the official test split
# (clamped so a large ref_size cannot index past the end of the dataset).
ref_set = Subset(test_set_o, range(0, min(int(ref_size), len(test_set_o))))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

In [3]:
# Data partitioning: each device gets a contiguous slice of the training set
# with one class removed, split into local train / test subsets.
device_dict = {}
loader_dict = {}
print('Device', *('Class' + str(class_id) for class_id in range(num_classes)), 'SUM', sep='\t')

for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size

    # remove one class from each local dataset (device i drops class i mod num_classes)
    class_to_remove = device_id % num_classes
    # as_tensor accepts both list targets (CIFAR10) and tensor targets
    # (MNIST/FMNIST) without the copy-construct warning torch.tensor() emits.
    local_targets = torch.as_tensor(train_set_o.targets[range_start:range_end])
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]
    # positions inside the slice -> indices into the full training set
    global_indices = indices + range_start

    # split train&test
    train_test_border = int((1 - test_ratio) * len(global_indices))
    train_set = Subset(train_set_o, global_indices[:train_test_border])
    test_set = Subset(train_set_o, global_indices[train_test_border:])
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=True, num_workers=0)

    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)

# initialize devices with the implementation selected by `Framework`
# (log_path is labelled with Framework, so the two must agree).
# NOTE(review): assumes Device and DSGD_Device accept the same constructor
# arguments as SISA_Device -- confirm. 'Fed' uses FedAvgServer/FedAvgClient,
# which this cell does not wire up.
device_classes = {'ours': Device, 'DSGD': DSGD_Device, 'SISA': SISA_Device}
if Framework not in device_classes:
    raise NotImplementedError(
        "Framework {!r} is not supported by this device setup".format(Framework))
DeviceClass = device_classes[Framework]

gpu_id = 0  # every device is placed on GPU 0
for device_id in range(device_num):
    device_dict[device_id] = DeviceClass(device_id, gpu_id, num_classes, num_channel)

# assign neighbors: full neighbor list (every other device)
for k, v in device_dict.items():
    v.neighbor_list = [x for x in device_dict.keys() if x != k]
    print('{}: {}'.format(k, v.neighbor_list))

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	905	788	811	798	701	797	852	752	795	7199
D-1	782	0	749	851	767	724	757	830	788	828	7076
D-2	757	919	0	819	781	751	803	813	759	782	7184
D-3	758	905	798	0	783	710	794	819	798	828	7193
D-4	804	901	817	807	0	719	773	842	792	785	7240
D-5	797	850	790	821	787	0	765	877	819	762	7268
0: [1, 2, 3, 4, 5]
1: [0, 2, 3, 4, 5]
2: [0, 1, 3, 4, 5]
3: [0, 1, 2, 4, 5]
4: [0, 1, 2, 3, 5]
5: [0, 1, 2, 3, 4]


In [4]:
# Attach a main model and an Adam optimizer to every device; the scenario
# flag decides which architecture is used and which label is logged.
# NOTE(review): in both branches every device gets the *same* architecture
# (small model when Heterogeneous == 1, large model otherwise). ResNet18 and
# MobileNet_M are imported but never used in this notebook, so a per-device
# mix of architectures may have been intended for the heterogeneous case --
# confirm.
if Heterogeneous == 1:
    scenario_label = 'heterogeneous'
    model_cls = ResNet8 if ModelType == 'resnet' else MobileNet_S
else:
    scenario_label = 'homogeneous'
    model_cls = ResNet50 if ModelType == 'resnet' else MobileNet_L
write_log(log_path, scenario_label)

for dev in device_dict.values():
    dev.main_model = model_cls(num_channel)
    dev.main_model.cuda(dev.gpu_id)
    dev.optimizer = optim.Adam(dev.main_model.parameters(), lr=0.01)

In [10]:
# Train every device's main model in rounds of `observe_gap` iterations.
# After each round, validate every device and log the metric columns
# averaged over devices.
write_log(log_path, time.ctime())
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
observe_gap = 10
num_rounds = iter_num // observe_gap  # any remainder iterations (< observe_gap) are not run
for epoch in range(num_rounds):
    print('epoch: ' + str(epoch))
    round_metrics = []
    for dev_id, dev in device_dict.items():
        local_train_loader, local_test_loader = loader_dict[dev_id]
        dev.train_main_model(num_iter=observe_gap, local_loader=local_train_loader)
        round_metrics.append(
            dev.validate_ensamble(test_loader=local_test_loader, device_dict=device_dict))
    metric_means = np.mean(np.array(round_metrics), axis=0)  # one mean per metric column
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
        metric_means[0], metric_means[1], metric_means[2], metric_means[3])
    print('mean: ' + log_txt)
    write_log(log_path, log_txt)

epoch: 0

Device: 0 main model training
Iter:   0		Loss: 0.00001620
Iter:   1		Loss: 0.00088802
Iter:   2		Loss: 0.00006108
Iter:   3		Loss: 0.00403624
Iter:   4		Loss: 0.00039813
Iter:   5		Loss: 0.00017781
Iter:   6		Loss: 0.00002797
Iter:   7		Loss: 0.00127497
Iter:   8		Loss: 0.00023705
Iter:   9		Loss: 0.00035186
SISA: Client 0 Test -  Accuracy: 1755.0/1800 (0.9750)

Device: 1 main model training
Iter:   0		Loss: 0.45545802
Iter:   1		Loss: 1.65545988
Iter:   2		Loss: 0.00032588
Iter:   3		Loss: 0.00023133
Iter:   4		Loss: 0.02839174
Iter:   5		Loss: 0.00000286
Iter:   6		Loss: 0.03431626
Iter:   7		Loss: 0.00347886
Iter:   8		Loss: 0.00002512
Iter:   9		Loss: 0.00000113
SISA: Client 1 Test -  Accuracy: 1735.0/1770 (0.9802)

Device: 2 main model training
Iter:   0		Loss: 0.00015931
Iter:   1		Loss: 0.00008731
Iter:   2		Loss: 0.00001720
Iter:   3		Loss: 0.32310006
Iter:   4		Loss: 0.04062113
Iter:   5		Loss: 0.00001661
Iter:   6		Loss: 0.00005796
Iter:   7		Loss: 0.00293080
Iter: 

Iter:   0		Loss: 0.00011167
Iter:   1		Loss: 0.01040037
Iter:   2		Loss: 0.00016698
Iter:   3		Loss: 0.35528487
Iter:   4		Loss: 0.00005895
Iter:   5		Loss: 0.00000143
Iter:   6		Loss: 0.00000027
Iter:   7		Loss: 0.00002143
Iter:   8		Loss: 0.00012800
Iter:   9		Loss: 0.00073570
SISA: Client 4 Test -  Accuracy: 1769.0/1810 (0.9773)

Device: 5 main model training
Iter:   0		Loss: 0.05299761
Iter:   1		Loss: 0.15155637
Iter:   2		Loss: 0.00591377
Iter:   3		Loss: 0.00374040
Iter:   4		Loss: 0.00000077
Iter:   5		Loss: 0.95722693
Iter:   6		Loss: 0.00305764
Iter:   7		Loss: 0.08616051
Iter:   8		Loss: 0.00032517
Iter:   9		Loss: 0.00051327
SISA: Client 5 Test -  Accuracy: 1791.0/1817 (0.9857)
mean: 0.9776	0.9759	0.9737	0.9708	
epoch: 4

Device: 0 main model training
Iter:   0		Loss: 0.00014628
Iter:   1		Loss: 0.01570264
Iter:   2		Loss: 0.00346394
Iter:   3		Loss: 0.00005076
Iter:   4		Loss: 0.00003754
Iter:   5		Loss: 0.00334847
Iter:   6		Loss: 0.00029199
Iter:   7		Loss: 0.00388781
It

Iter:   0		Loss: 0.00003996
Iter:   1		Loss: 0.00000530
Iter:   2		Loss: 0.00003215
Iter:   3		Loss: 0.00380732
Iter:   4		Loss: 0.01866508
Iter:   5		Loss: 0.00000111
Iter:   6		Loss: 0.00000209
Iter:   7		Loss: 0.00849900
Iter:   8		Loss: 0.00009115
Iter:   9		Loss: 0.00328661
SISA: Client 2 Test -  Accuracy: 1741.0/1797 (0.9688)

Device: 3 main model training
Iter:   0		Loss: 0.00001878
Iter:   1		Loss: 0.16252981
Iter:   2		Loss: 0.00012388
Iter:   3		Loss: 0.00004775
Iter:   4		Loss: 0.00026921
Iter:   5		Loss: 0.00710080
Iter:   6		Loss: 0.00627779
Iter:   7		Loss: 0.00006691
Iter:   8		Loss: 0.00205190
Iter:   9		Loss: 0.00064697
SISA: Client 3 Test -  Accuracy: 1742.0/1799 (0.9683)

Device: 4 main model training
Iter:   0		Loss: 0.00001276
Iter:   1		Loss: 0.00001351
Iter:   2		Loss: 0.00008015
Iter:   3		Loss: 0.04089452
Iter:   4		Loss: 0.01172743
Iter:   5		Loss: 0.00001264
Iter:   6		Loss: 0.00038734
Iter:   7		Loss: 0.00195342
Iter:   8		Loss: 0.00481934
Iter:   9		Loss: 0

Iter:   0		Loss: 0.00000478
Iter:   1		Loss: 0.00001760
Iter:   2		Loss: 0.00001794
Iter:   3		Loss: 0.00004267
Iter:   4		Loss: 0.00007105
Iter:   5		Loss: 0.00020199
Iter:   6		Loss: 0.00001997
Iter:   7		Loss: 0.00002916
Iter:   8		Loss: 0.00000929
Iter:   9		Loss: 0.00005774
SISA: Client 0 Test -  Accuracy: 1760.0/1800 (0.9778)

Device: 1 main model training
Iter:   0		Loss: 0.26646978
Iter:   1		Loss: 0.03533902
Iter:   2		Loss: 0.00222684
Iter:   3		Loss: 0.00065283
Iter:   4		Loss: 0.03040208
Iter:   5		Loss: 0.00001991
Iter:   6		Loss: 0.00024510
Iter:   7		Loss: 0.00000018
Iter:   8		Loss: 0.25199267
Iter:   9		Loss: 0.00124915
SISA: Client 1 Test -  Accuracy: 1733.0/1770 (0.9791)

Device: 2 main model training
Iter:   0		Loss: 0.00014399
Iter:   1		Loss: 0.00020253
Iter:   2		Loss: 0.00010359
Iter:   3		Loss: 0.00003117
Iter:   4		Loss: 0.00020300
Iter:   5		Loss: 0.00001538
Iter:   6		Loss: 0.00001629
Iter:   7		Loss: 0.00002331
Iter:   8		Loss: 0.00002336
Iter:   9		Loss: 0

Iter:   1		Loss: 0.00032121
Iter:   2		Loss: 0.09787635
Iter:   3		Loss: 0.00217856
Iter:   4		Loss: 0.00042382
Iter:   5		Loss: 0.06129724
Iter:   6		Loss: 0.00088357
Iter:   7		Loss: 0.02635096
Iter:   8		Loss: 0.00024381
Iter:   9		Loss: 0.00007100
SISA: Client 4 Test -  Accuracy: 1761.0/1810 (0.9729)

Device: 5 main model training
Iter:   0		Loss: 0.00000688
Iter:   1		Loss: 0.13778800
Iter:   2		Loss: 0.00011665
Iter:   3		Loss: 0.01012409
Iter:   4		Loss: 0.00004544
Iter:   5		Loss: 0.22058636
Iter:   6		Loss: 0.00128625
Iter:   7		Loss: 0.00066104
Iter:   8		Loss: 0.00216873
Iter:   9		Loss: 0.01636566
SISA: Client 5 Test -  Accuracy: 1795.0/1817 (0.9879)
mean: 0.9760	0.9736	0.9723	0.9688	
epoch: 15

Device: 0 main model training
Iter:   0		Loss: 0.00012956
Iter:   1		Loss: 0.00015193
Iter:   2		Loss: 0.00002156
Iter:   3		Loss: 0.00004429
Iter:   4		Loss: 0.00495744
Iter:   5		Loss: 0.00110897
Iter:   6		Loss: 0.00012009
Iter:   7		Loss: 0.00035220
Iter:   8		Loss: 0.00030524
I

Iter:   1		Loss: 0.00270113
Iter:   2		Loss: 0.00524925
Iter:   3		Loss: 0.00011670
Iter:   4		Loss: 0.00001733
Iter:   5		Loss: 0.07609628
Iter:   6		Loss: 0.00256495
Iter:   7		Loss: 0.00027569
Iter:   8		Loss: 0.00051136
Iter:   9		Loss: 0.00055162
SISA: Client 2 Test -  Accuracy: 1761.0/1797 (0.9800)

Device: 3 main model training
Iter:   0		Loss: 0.00006325
Iter:   1		Loss: 0.01656859
Iter:   2		Loss: 0.00002156
Iter:   3		Loss: 0.00112563
Iter:   4		Loss: 0.00000399
Iter:   5		Loss: 0.00347866
Iter:   6		Loss: 0.01462674
Iter:   7		Loss: 0.00004860
Iter:   8		Loss: 0.00013662
Iter:   9		Loss: 0.00290873
SISA: Client 3 Test -  Accuracy: 1748.0/1799 (0.9717)

Device: 4 main model training
Iter:   0		Loss: 0.42211074
Iter:   1		Loss: 0.01058948
Iter:   2		Loss: 0.00002274
Iter:   3		Loss: 0.00000006
Iter:   4		Loss: 0.02650815
Iter:   5		Loss: 0.00017971
Iter:   6		Loss: 0.00006482
Iter:   7		Loss: 0.00130539
Iter:   8		Loss: 0.61847186
Iter:   9		Loss: 0.00000396
SISA: Client 4 Tes

Iter:   1		Loss: 0.01606199
Iter:   2		Loss: 0.00007721
Iter:   3		Loss: 0.00001882
Iter:   4		Loss: 0.00021762
Iter:   5		Loss: 0.00703124
Iter:   6		Loss: 0.00001516
Iter:   7		Loss: 0.00112577
Iter:   8		Loss: 0.00614511
Iter:   9		Loss: 0.00057416
SISA: Client 0 Test -  Accuracy: 1760.0/1800 (0.9778)

Device: 1 main model training
Iter:   0		Loss: 0.00014042
Iter:   1		Loss: 0.00013905
Iter:   2		Loss: 0.00000009
Iter:   3		Loss: 0.01292852
Iter:   4		Loss: 0.01134157
Iter:   5		Loss: 0.00232388
Iter:   6		Loss: 1.08265579
Iter:   7		Loss: 0.01384871
Iter:   8		Loss: 0.00009234
Iter:   9		Loss: 0.00005876
SISA: Client 1 Test -  Accuracy: 1737.0/1770 (0.9814)

Device: 2 main model training
Iter:   0		Loss: 0.07378308
Iter:   1		Loss: 0.00226585
Iter:   2		Loss: 0.00066718
Iter:   3		Loss: 0.01515581
Iter:   4		Loss: 0.00011079
Iter:   5		Loss: 0.01470479
Iter:   6		Loss: 0.00000089
Iter:   7		Loss: 0.00000075
Iter:   8		Loss: 0.00337260
Iter:   9		Loss: 0.04825590
SISA: Client 2 Tes

Iter:   2		Loss: 0.00030980
Iter:   3		Loss: 0.00000040
Iter:   4		Loss: 0.16628635
Iter:   5		Loss: 0.17820755
Iter:   6		Loss: 0.00047983
Iter:   7		Loss: 0.00015587
Iter:   8		Loss: 0.00110626
Iter:   9		Loss: 0.00843094
SISA: Client 4 Test -  Accuracy: 1770.0/1810 (0.9779)

Device: 5 main model training
Iter:   0		Loss: 0.00002792
Iter:   1		Loss: 0.00128509
Iter:   2		Loss: 0.00006928
Iter:   3		Loss: 0.00023738
Iter:   4		Loss: 0.00042251
Iter:   5		Loss: 0.00000069
Iter:   6		Loss: 0.00001287
Iter:   7		Loss: 1.07015276
Iter:   8		Loss: 0.03521837
Iter:   9		Loss: 0.00008153
SISA: Client 5 Test -  Accuracy: 1793.0/1817 (0.9868)
mean: 0.9756	0.9722	0.9718	0.9681	
epoch: 26

Device: 0 main model training
Iter:   0		Loss: 0.00231741
Iter:   1		Loss: 0.00001263
Iter:   2		Loss: 0.00046264
Iter:   3		Loss: 0.00003481
Iter:   4		Loss: 0.00004434
Iter:   5		Loss: 0.00025356
Iter:   6		Loss: 0.00026677
Iter:   7		Loss: 0.00003049
Iter:   8		Loss: 0.00002100
Iter:   9		Loss: 0.00020820
S

Iter:   2		Loss: 0.00019797
Iter:   3		Loss: 0.04739770
Iter:   4		Loss: 0.00000526
Iter:   5		Loss: 0.00177667
Iter:   6		Loss: 0.00027838
Iter:   7		Loss: 0.08126126
Iter:   8		Loss: 0.00000051
Iter:   9		Loss: 0.00001328
SISA: Client 2 Test -  Accuracy: 1748.0/1797 (0.9727)

Device: 3 main model training
Iter:   0		Loss: 0.00008012
Iter:   1		Loss: 0.00129111
Iter:   2		Loss: 0.00001827
Iter:   3		Loss: 0.00003971
Iter:   4		Loss: 0.00006039
Iter:   5		Loss: 0.00000125
Iter:   6		Loss: 0.00006848
Iter:   7		Loss: 0.00000699
Iter:   8		Loss: 0.00002467
Iter:   9		Loss: 0.00004811
SISA: Client 3 Test -  Accuracy: 1743.0/1799 (0.9689)

Device: 4 main model training
Iter:   0		Loss: 0.00005702
Iter:   1		Loss: 0.01251558
Iter:   2		Loss: 0.00067856
Iter:   3		Loss: 0.00007077
Iter:   4		Loss: 0.00234523
Iter:   5		Loss: 0.00000039
Iter:   6		Loss: 0.00002871
Iter:   7		Loss: 0.02834191
Iter:   8		Loss: 0.00005172
Iter:   9		Loss: 0.00004142
SISA: Client 4 Test -  Accuracy: 1765.0/1810 (

In [9]:
# Inspect the configured iteration budget.
# NOTE(review): the saved output of this cell (300) matches none of the values
# this notebook assigns to iter_num (100 / 900 / 1000), and its execution count
# (In [9]) precedes the training cell's (In [10]). The output therefore reflects
# stale kernel state; restart and run all before trusting it.
iter_num

300