In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# one in ['ours', 'DSGD', 'SISA', 'Fed']
Framework = 'DSGD'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'MNIST'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1,0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 1

In [2]:
# ---- Data loading & run-level settings -------------------------------------
# Defaults below are for MNIST/FMNIST; the CIFAR10 branch overrides several.
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 6
num_channel = 1
ref_size = 10000      # number of leading test-split samples placed in ref_set
test_ratio = 0.2      # fraction of each device's local data held out for local testing
CIFAR10_segmentation = 0

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG the notebook relies on. Seeding only CUDA left DataLoader
# shuffling (CPU torch generator) and CPU-side initialisation unseeded.
# torch.manual_seed also seeds all CUDA devices.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)
epoch_num = 100

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    device_num = 5
    num_channel = 3
    epoch_num = 400
    CIFAR10_segmentation = 1
elif DatasetName == 'MNIST':
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)
    epoch_num = 200
elif DatasetName == 'FMNIST':
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)
    epoch_num = 300
else:
    raise ValueError("DatasetName must be one of 'MNIST', 'FMNIST', 'CIFAR10', got {!r}".format(DatasetName))

# Each device gets an equal contiguous slice of the training split,
# e.g. 60000 // 6 = 10000 samples for MNIST.
train_test_total_size = len(train_set_o) // device_num

# ToTensor maps pixels to [0, 1]; normalising every channel with mean=std=0.5
# maps them to [-1, 1]. One pipeline is shared by both splits.
normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * num_channel, (0.5,) * num_channel),
])
train_set_o.transform = normalize_transform
test_set_o.transform = normalize_transform

# Reference set: the first ref_size samples of the official test split.
ref_set = Subset(test_set_o, range(ref_size))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

In [3]:
# ---- Per-device data partitioning & device setup ---------------------------
# This notebook only instantiates DSGD_Device below; running it with another
# Framework would silently train DSGD while logging under that framework's name.
if Framework != 'DSGD':
    raise ValueError('This notebook only builds DSGD_Device objects; got Framework={!r}'.format(Framework))

device_dict = {}
loader_dict = {}
# Header row for the per-device label statistics printed by print_label_stat.
print('Device', *['Class' + str(class_id) for class_id in range(num_classes)], 'SUM', sep='\t')

for device_id in range(device_num):
    # Contiguous slice of the training split owned by this device.
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size

    # Remove one class from each local dataset (cycles through the classes).
    class_to_remove = torch.tensor(device_id % num_classes)
    # as_tensor avoids re-copying (and the copy-construct warning) when targets
    # is already a tensor (MNIST/FMNIST) and converts the plain list CIFAR10 uses.
    local_targets = torch.as_tensor(train_set_o.targets[range_start:range_end])
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]

    # Split train & test. `indices` are relative to the slice, hence + range_start.
    train_test_border = int((1 - test_ratio) * len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border] + range_start)
    test_set = Subset(train_set_o, indices[train_test_border:] + range_start)
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    # NOTE(review): shuffling the local evaluation loader is presumably
    # unnecessary; kept as-is because changing it alters the RNG stream.
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=True, num_workers=0)

    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)


# Initialize devices (all placed on GPU 0).
gpu_id = 0
for device_id in range(device_num):
    device_dict[device_id] = DSGD_Device(device_id, gpu_id, num_classes, num_channel)


# Assign neighbors: fully connected topology, every other device is a neighbor.
for k, v in device_dict.items():
    neighbor_list = [x for x in device_dict if x != k]
    v.neighbor_list = neighbor_list

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	746	648	681	681	575	673	715	614	663	5996
D-1	657	0	648	699	641	593	656	699	644	675	5912
D-2	649	768	0	701	674	582	655	660	631	684	6004
D-3	655	763	663	0	640	613	682	685	649	668	6018
D-4	653	758	665	680	0	592	673	675	654	662	6012
D-5	646	751	681	677	627	0	642	707	675	657	6063


In [4]:
# ---- Model & optimizer assignment ------------------------------------------
# Pick the scenario label and model constructor once, then give every device its
# own freshly constructed model (moved to the device's GPU) and an Adam optimizer.
# NOTE(review): in the heterogeneous scenario every device still receives the same
# (small) architecture; confirm that is the intended baseline setup.
if Heterogeneous == 1:
    scenario_label = 'heterogeneous'
    build_model = ResNet8 if ModelType == 'resnet' else MobileNet_S
else:
    scenario_label = 'homogeneous'
    build_model = ResNet50 if ModelType == 'resnet' else MobileNet_L

write_log(log_path, scenario_label)
for dev in device_dict.values():
    dev.model = build_model(num_channel)
    dev.model.cuda(dev.gpu_id)
    dev.optimizer = optim.Adam(dev.model.parameters(), lr=0.01)

In [None]:
# ---- Training ----------------------------------------------------------------
# Each epoch: every device trains locally (num_iter=3) and then calls its
# communication step inside the same loop, so presumably later devices exchange
# with neighbours that were already updated this epoch.
# NOTE(review): confirm this sequential train->communicate ordering is intended.
# Afterwards each device is evaluated on its local test split; the mean over
# devices of the first four returned metrics is printed and logged.
write_log(log_path, time.ctime(time.time()))
for epoch in range(epoch_num):
    for device_id, device in device_dict.items():
        device.train_local_model(num_iter=3, local_loader=loader_dict[device_id][0])
        device.communication(frac=1, gamma=0.9, device_dict=device_dict)
    metric = [device.validate_local_model(test_loader=loader_dict[device_id][1])
              for device_id, device in device_dict.items()]
    metric_arr = np.array(metric)
    # Average each metric column over devices once (was recomputed per column).
    mean_metrics = np.mean(metric_arr, axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metrics[:4])
    print('mean: ' + log_txt)
    write_log(log_path, log_txt)

Client 0	Epoch:  0		Loss: 0.94650245
Client 1	Epoch:  0		Loss: 0.88405657
Client 2	Epoch:  0		Loss: 0.86257017
Client 3	Epoch:  0		Loss: 1.80002189
Client 4	Epoch:  0		Loss: 0.69426745
Client 5	Epoch:  0		Loss: 0.60473144
Server - Avg_loss: 0.0834, Acc: 160.0/1499 (0.1067)
Server - Avg_loss: 0.1006, Acc: 143.0/1479 (0.0967)
Server - Avg_loss: 0.0805, Acc: 161.0/1501 (0.1073)
Server - Avg_loss: 0.0919, Acc: 158.0/1505 (0.1050)
Server - Avg_loss: 0.0798, Acc: 158.0/1504 (0.1051)
Server - Avg_loss: 0.0887, Acc: 174.0/1516 (0.1148)
mean: 0.1059	0.0121	0.1109	0.0214	
Client 0	Epoch:  0		Loss: 0.95900005
Client 1	Epoch:  0		Loss: 0.51814646
Client 2	Epoch:  0		Loss: 0.38900074
Client 3	Epoch:  0		Loss: 0.57502472
Client 4	Epoch:  0		Loss: 0.44001913
Client 5	Epoch:  0		Loss: 0.33357447
Server - Avg_loss: 0.1072, Acc: 160.0/1499 (0.1067)
Server - Avg_loss: 0.1043, Acc: 143.0/1479 (0.0967)
Server - Avg_loss: 0.0906, Acc: 161.0/1501 (0.1073)
Server - Avg_loss: 0.1011, Acc: 158.0/1505 (0.1050)
S

Server - Avg_loss: 0.3927, Acc: 156.0/1479 (0.1055)
Server - Avg_loss: 0.3691, Acc: 339.0/1501 (0.2258)
Server - Avg_loss: 0.3221, Acc: 199.0/1505 (0.1322)
Server - Avg_loss: 0.4007, Acc: 185.0/1504 (0.1230)
Server - Avg_loss: 0.4751, Acc: 166.0/1516 (0.1095)
mean: 0.1381	0.0284	0.1348	0.0404	
Client 0	Epoch:  0		Loss: 0.07398232
Client 1	Epoch:  0		Loss: 0.31012046
Client 2	Epoch:  0		Loss: 0.34553576
Client 3	Epoch:  0		Loss: 0.00193233
Client 4	Epoch:  0		Loss: 0.09379441
Client 5	Epoch:  0		Loss: 0.09410806
Server - Avg_loss: 0.3809, Acc: 304.0/1499 (0.2028)
Server - Avg_loss: 0.4371, Acc: 156.0/1479 (0.1055)
Server - Avg_loss: 0.4557, Acc: 177.0/1501 (0.1179)
Server - Avg_loss: 0.3224, Acc: 262.0/1505 (0.1741)
Server - Avg_loss: 0.4117, Acc: 334.0/1504 (0.2221)
Server - Avg_loss: 0.4353, Acc: 166.0/1516 (0.1095)
mean: 0.1553	0.0546	0.1481	0.0613	
Client 0	Epoch:  0		Loss: 0.38030398
Client 1	Epoch:  0		Loss: 0.07640021
Client 2	Epoch:  0		Loss: 0.21259353
Client 3	Epoch:  0		Loss:

Client 0	Epoch:  0		Loss: 0.02933557
Client 1	Epoch:  0		Loss: 0.09209902
Client 2	Epoch:  0		Loss: 0.06416561
Client 3	Epoch:  0		Loss: 0.01941312
Client 4	Epoch:  0		Loss: 0.20769314
Client 5	Epoch:  0		Loss: 0.19593011
Server - Avg_loss: 0.1607, Acc: 342.0/1499 (0.2282)
Server - Avg_loss: 0.3654, Acc: 205.0/1479 (0.1386)
Server - Avg_loss: 0.3078, Acc: 178.0/1501 (0.1186)
Server - Avg_loss: 0.2812, Acc: 298.0/1505 (0.1980)
Server - Avg_loss: 0.3529, Acc: 180.0/1504 (0.1197)
Server - Avg_loss: 0.3171, Acc: 166.0/1516 (0.1095)
mean: 0.1521	0.0467	0.1484	0.0552	
Client 0	Epoch:  0		Loss: 0.04646258
Client 1	Epoch:  0		Loss: 0.04637934
Client 2	Epoch:  0		Loss: 0.47367841
Client 3	Epoch:  0		Loss: 0.28427440
Client 4	Epoch:  0		Loss: 0.18166234
Client 5	Epoch:  0		Loss: 0.06352603
Server - Avg_loss: 0.2089, Acc: 187.0/1499 (0.1247)
Server - Avg_loss: 0.3791, Acc: 283.0/1479 (0.1913)
Server - Avg_loss: 0.2592, Acc: 192.0/1501 (0.1279)
Server - Avg_loss: 0.2697, Acc: 277.0/1505 (0.1841)
S

In [None]:
# ---- Continued training --------------------------------------------------------
# Runs another epoch_num epochs starting from the current device models (nothing
# is re-initialised here). Unlike the first training cell, no timestamp is written
# to the log, so these epochs are appended directly after the previous ones.
for epoch in range(epoch_num):
    for device_id, device in device_dict.items():
        device.train_local_model(num_iter=3, local_loader=loader_dict[device_id][0])
        device.communication(frac=1, gamma=0.9, device_dict=device_dict)
    metric = [device.validate_local_model(test_loader=loader_dict[device_id][1])
              for device_id, device in device_dict.items()]
    metric_arr = np.array(metric)
    # Average each metric column over devices once (was recomputed per column).
    mean_metrics = np.mean(metric_arr, axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(*mean_metrics[:4])
    print('mean: ' + log_txt)
    write_log(log_path, log_txt)