In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# ---------------------------------------------------------------------------
# Experiment switches. Edit these four values, then Restart Kernel -> Run All.
# ---------------------------------------------------------------------------

# one in ['ours', 'DSGD', 'SISA', 'Fed']
# NOTE(review): in the cells reviewed, this value only feeds the log file name;
# the clients/server that get built are always the FedAvg ones -- confirm intent.
Framework = 'Fed'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'FMNIST'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1,0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 0

# Fail fast on a typo: otherwise an unknown dataset name leaves the dataset
# variables as None and an unknown model type silently falls back to mobilenet.
_ALLOWED_SETTINGS = {
    'Framework': (Framework, ('ours', 'DSGD', 'SISA', 'Fed')),
    'DatasetName': (DatasetName, ('MNIST', 'FMNIST', 'CIFAR10')),
    'ModelType': (ModelType, ('resnet', 'mobilenet')),
    'Heterogeneous': (Heterogeneous, (0, 1)),
}
for _name, (_value, _choices) in _ALLOWED_SETTINGS.items():
    if _value not in _choices:
        raise ValueError(
            '{} must be one of {}, got {!r}'.format(_name, list(_choices), _value)
        )

In [2]:
# ---- Defaults (valid for MNIST / FMNIST); the CIFAR10 branch overrides some ----
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 6
num_channel = 1
ref_size = 10000
# Number of consecutive samples of the original training split given to each device.
train_test_total_size = 60000 // device_num
test_ratio = 0.2
CIFAR10_segmentation = 0

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG in play. Seeding only CUDA leaves the CPU torch generator
# (used for DataLoader shuffling and CPU-side initialisation), NumPy and
# `random` unseeded, so two runs would not be comparable.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)
torch.cuda.manual_seed_all(my_seed)
epoch_num = 100

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    device_num = 5
    num_channel = 3
    CIFAR10_segmentation = 1
    train_test_total_size = 50000 // device_num
    epoch_num = 400

elif DatasetName == 'MNIST':
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)
    epoch_num = 100

elif DatasetName == 'FMNIST':
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)
    epoch_num = 300

else:
    # Without this branch the datasets stay None and the failure shows up
    # later as an opaque AttributeError on `.transform`.
    raise ValueError(
        "DatasetName must be one of ['MNIST', 'FMNIST', 'CIFAR10'], got {!r}".format(DatasetName)
    )

# Same preprocessing for train and test: scale to [0, 1] then shift/scale each
# channel by 0.5. The statistics are spelled out per channel so the 3-channel
# CIFAR10 case does not depend on implicit broadcasting of a 1-tuple.
channel_stats = (0.5,) * num_channel
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
])
train_set_o.transform = image_transform
test_set_o.transform = image_transform

# Reference set: the first `ref_size` samples of the official test split.
ref_set = Subset(test_set_o, range(0, int(ref_size)))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

In [3]:
# data manipulation: give each device a contiguous slice of the training split
# with one class removed, split that slice into local train/test parts, then
# create the FedAvg clients and the server.
device_dict = {}
loader_dict = {}
# Defined once up front: it is needed by the server constructor after the
# client loop, so it must not depend on that loop having executed.
gpu_id = 0

header_cells = ['Device'] + ['Class' + str(class_id) for class_id in range(num_classes)] + ['SUM']
print('\t'.join(header_cells))

# `targets` is a tensor for MNIST/FMNIST and a plain list for CIFAR10;
# as_tensor handles both without re-wrapping an existing tensor.
all_targets = torch.as_tensor(train_set_o.targets)

for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size

    # remove one class from each local dataset
    class_to_remove = device_id % num_classes
    local_targets = all_targets[range_start:range_end]
    # Positions (relative to range_start) of the samples that are kept.
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]

    # split train&test
    train_test_border = int((1 - test_ratio) * len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border] + range_start)
    test_set = Subset(train_set_o, indices[train_test_border:] + range_start)
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    # Evaluation does not need a random order; a fixed order keeps it deterministic.
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=False, num_workers=0)

    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)

# initialize devices
for device_id in range(device_num):
    device_dict[device_id] = FedAvgClient(device_id, gpu_id=gpu_id,
                                          num_classes=num_classes, num_channel=num_channel)

# Pass a concrete list rather than a live dict view, so the server's id list
# cannot change underneath it if device_dict is modified later.
Fed_server = FedAvgServer(Device_id_list=list(device_dict.keys()), gpu_id=gpu_id,
                          num_classes=num_classes, num_channel=num_channel)

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	720	676	687	644	658	660	676	662	663	6046
D-1	669	0	643	659	681	663	697	637	659	691	5999
D-2	687	637	0	667	665	701	697	672	632	634	5992
D-3	658	645	647	0	638	689	682	677	718	631	5985
D-4	682	654	638	652	0	643	657	676	696	686	5984
D-5	634	676	706	652	683	0	660	656	675	670	6012


In [4]:
# Choose the scenario once, log it, then give every client and the server a
# freshly constructed model of the selected class.
#   Heterogeneous == 1 -> ResNet8  / MobileNet_S
#   otherwise          -> ResNet50 / MobileNet_L
is_heterogeneous = (Heterogeneous == 1)
write_log(log_path, 'heterogeneous' if is_heterogeneous else 'homogeneous')

if ModelType == 'resnet':
    build_model = ResNet8 if is_heterogeneous else ResNet50
else:
    build_model = MobileNet_S if is_heterogeneous else MobileNet_L

# Clients first, server last -- the same construction order as before.
for fed_client in device_dict.values():
    fed_client.client_model = build_model(num_channel)
    fed_client.client_model.cuda(fed_client.gpu_id)
    fed_client.optimizer = optim.Adam(fed_client.client_model.parameters(), lr=0.01)

Fed_server.central_model = build_model(num_channel)
Fed_server.central_model.cuda(Fed_server.gpu_id)

In [None]:
# Federated training: each round runs local client updates, aggregates them on
# the server, pushes the central model back out, then validates the central
# model on every client's local test loader and logs the column-wise mean.
write_log(log_path, time.ctime(time.time()))
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
for epoch in range(epoch_num):
    # 1) local update on every client
    for client_id in device_dict:
        device_dict[client_id].update_client_model(num_iter=3,
                                                   local_loader=loader_dict[client_id][0])

    # 2) aggregate on the server, 3) send the result back to the clients
    Fed_server.update_central_model(frac=1, device_dict=device_dict)
    Fed_server.distribute_central_model(device_dict=device_dict)

    # 4) one validation result per client, in client order
    metric = [
        Fed_server.validate_central_model(test_loader=loader_dict[client_id][1])
        for client_id in device_dict
    ]
    metric_arr = np.array(metric)

    # Average over clients once and reuse it for all four logged columns.
    mean_metric = np.mean(metric_arr, axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t'.format(
        mean_metric[0], mean_metric[1], mean_metric[2], mean_metric[3]
    )
    print('epoch: {} mean: {}'.format(epoch, log_txt))
    write_log(log_path, log_txt)

Client 0	Epoch:  0		Loss: 0.74684054
Client 1	Epoch:  0		Loss: 0.73502612
Client 2	Epoch:  0		Loss: 0.25984645
Client 3	Epoch:  0		Loss: 2.48468733
Client 4	Epoch:  0		Loss: 0.48650038
Client 5	Epoch:  0		Loss: 0.59811240
Server - Avg_loss: 0.0732, Acc: 157.0/1512 (0.1038)
Server - Avg_loss: 0.0720, Acc: 151.0/1500 (0.1007)
Server - Avg_loss: 0.0723, Acc: 179.0/1499 (0.1194)
Server - Avg_loss: 0.0725, Acc: 165.0/1497 (0.1102)
Server - Avg_loss: 0.0726, Acc: 0.0/1497 (0.0000)
Server - Avg_loss: 0.0722, Acc: 135.0/1503 (0.0898)
epoch: 0 mean: 0.0873	0.0099	0.0915	0.0176	
Client 0	Epoch:  0		Loss: 0.82807499
Client 1	Epoch:  0		Loss: 1.21552014
Client 2	Epoch:  0		Loss: 1.08323014
Client 3	Epoch:  0		Loss: 1.04616618
Client 4	Epoch:  0		Loss: 0.63886571
Client 5	Epoch:  0		Loss: 0.85985363
Server - Avg_loss: 0.0731, Acc: 165.0/1512 (0.1091)
Server - Avg_loss: 0.0717, Acc: 150.0/1500 (0.1000)
Server - Avg_loss: 0.0724, Acc: 164.0/1499 (0.1094)
Server - Avg_loss: 0.0725, Acc: 154.0/1497 (0.