In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# one in ['ours', 'DSGD', 'SISA', 'Fed']
Framework = 'ours'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'MNIST'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1,0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 1

In [2]:
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 5
num_channel = 1
ref_size = 10000
train_test_total_size = int(50000/device_num)
test_ratio = 0.2

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/unlearn_{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

my_seed = 1
torch.cuda.manual_seed(my_seed)

if DatasetName == 'CIFAR10':
    train_set_o = datasets.CIFAR10(data_path, train=True, download=True)
    test_set_o = datasets.CIFAR10(data_path, train=False, download=True)
    num_channel = 3

elif DatasetName == 'MNIST':    
    train_set_o = datasets.MNIST(data_path, train=True, download=True)
    test_set_o = datasets.MNIST(data_path, train=False, download=True)

elif DatasetName == 'FMNIST':    
    train_set_o = datasets.FashionMNIST(data_path, train=True, download=True)
    test_set_o = datasets.FashionMNIST(data_path, train=False, download=True)

train_set_o.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

test_set_o.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

ref_set = Subset(test_set_o, range(0,int(ref_size)))
ref_loader = torch.utils.data.DataLoader(ref_set, batch_size=train_batch_size,
                                         shuffle=False, num_workers=0)

In [3]:
# data manipulation
device_dict = {}
loader_dict = {}
print('Device', end='\t')
for class_id in range(num_classes):
    print('Class'+str(class_id), end='\t')
print('SUM')
for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size
    # remove one class from each local dataset
    class_to_remove = torch.tensor(device_id%10)
    indices = (torch.tensor(train_set_o.targets[range_start:range_end])[..., None] !=
               class_to_remove).any(-1).nonzero(as_tuple=True)[0]
    # split train&test
    train_test_border = int((1-test_ratio)*len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border]+range_start)
    test_set = Subset(train_set_o, indices[train_test_border:]+range_start)
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=True, num_workers=0) 
    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)

# initialize devices
for device_id in range(device_num):
    gpu_id = 0
    device_dict[device_id] = Device(device_id, gpu_id, num_classes, num_channel)

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	905	788	811	798	701	797	852	752	795	7199
D-1	782	0	749	851	767	724	757	830	788	828	7076
D-2	757	919	0	819	781	751	803	813	759	782	7184
D-3	758	905	798	0	783	710	794	819	798	828	7193
D-4	804	901	817	807	0	719	773	842	792	785	7240


In [4]:
# assign neighbors
for k, v in device_dict.items():
    # full neighbor list SMLLL'
    neighbor_list = [x for x in device_dict.keys() if x != k]
    v.neighbor_list = neighbor_list
    
for k, v in device_dict.items():
    print('D_{}: {}'.format(k, v.neighbor_list))

D_0: [1, 2, 3, 4]
D_1: [0, 2, 3, 4]
D_2: [0, 1, 3, 4]
D_3: [0, 1, 2, 4]
D_4: [0, 1, 2, 3]


In [5]:
# heterogenerous scenario 
write_log(log_path, DatasetName+' heterogenerous')
if Heterogeneous == 1:
    # heterogeneous scenario
    write_log(log_path, 'heterogeneous')
    for device_id in range(device_num):
        gpu_id = 0
        if device_id < 1:
            device_dict[device_id].main_model = ResNet8(num_channel) if ModelType == 'resnet' else MobileNet_S(num_channel)
        elif device_id < 2:
            device_dict[device_id].main_model = ResNet18(num_channel) if ModelType == 'resnet' else MobileNet_M(num_channel)
        else:
            device_dict[device_id].main_model = ResNet50(num_channel) if ModelType == 'resnet' else MobileNet_L(num_channel)
        device_dict[device_id].main_model.cuda(gpu_id)

else:
    # homogeneous scenario
    write_log(log_path, 'homogeneous')
    for device_id in range(device_num):
        gpu_id = 0
        device_dict[device_id].main_model = ResNet50(num_channel) if ModelType == 'resnet' else MobileNet_L(num_channel)
        device_dict[device_id].main_model.cuda(gpu_id)

In [18]:
# train local main models
num_iter = 1000 if ModelType == 'resnet' else 501
metric = []
for k, v in device_dict.items():
#     v.train_main_model(num_iter, loader_dict[k][0])
    metric.append(v.validate_main_model(loader_dict[k][1]))
    v.update_soft_label(ref_loader)
    GPU_info([v.gpu_id])
metric_arr = np.array(metric)
log_txt = 'Avg_acc: {:.4f}'.format(np.mean(metric_arr, axis=0)[0])
print(log_txt)

Device: 0 Val_main - Avg_loss: 0.1614, Acc: 1745.0/1800 (0.9694)
GPU0-0.0563G  
Device: 1 Val_main - Avg_loss: 0.0725, Acc: 1749.0/1770 (0.9881)
GPU0-0.0563G  
Device: 2 Val_main - Avg_loss: 0.0502, Acc: 1778.0/1797 (0.9894)
GPU0-0.0563G  
Device: 3 Val_main - Avg_loss: 0.0562, Acc: 1779.0/1799 (0.9889)
GPU0-0.0563G  
Device: 4 Val_main - Avg_loss: 0.0515, Acc: 1790.0/1810 (0.9890)
GPU0-0.0563G  
Avg_acc: 0.9850


In [19]:
# train seed models
T_ = 9
num_iter = 800 if ModelType == 'resnet' else 3001
for k, v in device_dict.items():
    v.seed_model = ResNet8(num_channel) if ModelType == 'resnet' else MobileNet_S(num_channel)
    v.seed_model.cuda(gpu_id) 
    v.train_seed_model(ref_set, num_iter=num_iter, batch_size=int(ref_size*0.2), T=T_)
    print('\n')

Device:0 Trn_Seed - Epoch:   0 	Loss: 0.900630
Device:0 Trn_Seed - Epoch:  50 	Loss: 0.730935
Device:0 Trn_Seed - Epoch: 100 	Loss: 0.545694
Device:0 Trn_Seed - Epoch: 150 	Loss: 0.388388
Device:0 Trn_Seed - Epoch: 200 	Loss: 0.285739
Device:0 Trn_Seed - Epoch: 250 	Loss: 0.219916
Device:0 Trn_Seed - Epoch: 300 	Loss: 0.173411
Device:0 Trn_Seed - Epoch: 350 	Loss: 0.144517
Device:0 Trn_Seed - Epoch: 400 	Loss: 0.122088
Device:0 Trn_Seed - Epoch: 450 	Loss: 0.109289
Device:0 Trn_Seed - Epoch: 500 	Loss: 0.092158
Device:0 Trn_Seed - Epoch: 550 	Loss: 0.081440
Device:0 Trn_Seed - Epoch: 600 	Loss: 0.073397
Device:0 Trn_Seed - Epoch: 650 	Loss: 0.065570
Device:0 Trn_Seed - Epoch: 700 	Loss: 0.059073
Device:0 Trn_Seed - Epoch: 750 	Loss: 0.056906
Device:0 Trn_Seed - Epoch: 800 	Loss: 0.053968
Device:0 Trn_Seed - Epoch: 850 	Loss: 0.051793
Device:0 Trn_Seed - Epoch: 900 	Loss: 0.052474
Device:0 Trn_Seed - Epoch: 950 	Loss: 0.049455
Device:0 Trn_Seed - Epoch: 1000 	Loss: 0.045658
Device:0 Trn

Device:2 Trn_Seed - Epoch: 2500 	Loss: 0.057213
Device:2 Trn_Seed - Epoch: 2550 	Loss: 0.058556
Device:2 Trn_Seed - Epoch: 2600 	Loss: 0.058231
Device:2 Trn_Seed - Epoch: 2650 	Loss: 0.058025
Device:2 Trn_Seed - Epoch: 2700 	Loss: 0.056436
Device:2 Trn_Seed - Epoch: 2750 	Loss: 0.059217
Device:2 Trn_Seed - Epoch: 2800 	Loss: 0.055289
Device:2 Trn_Seed - Epoch: 2850 	Loss: 0.055426
Device:2 Trn_Seed - Epoch: 2900 	Loss: 0.054843
Device:2 Trn_Seed - Epoch: 2950 	Loss: 0.055026
Device:2 Trn_Seed - Epoch: 3000 	Loss: 0.054920


Device:3 Trn_Seed - Epoch:   0 	Loss: 1.068792
Device:3 Trn_Seed - Epoch:  50 	Loss: 0.888920
Device:3 Trn_Seed - Epoch: 100 	Loss: 0.696302
Device:3 Trn_Seed - Epoch: 150 	Loss: 0.551328
Device:3 Trn_Seed - Epoch: 200 	Loss: 0.464246
Device:3 Trn_Seed - Epoch: 250 	Loss: 0.383986
Device:3 Trn_Seed - Epoch: 300 	Loss: 0.304591
Device:3 Trn_Seed - Epoch: 350 	Loss: 0.251972
Device:3 Trn_Seed - Epoch: 400 	Loss: 0.222261
Device:3 Trn_Seed - Epoch: 450 	Loss: 0.189559


In [28]:
# assign neighbors
for k, v in device_dict.items():
    # full neighbor list
    neighbor_list = [x for x in device_dict.keys() if x != k]
    v.neighbor_list = neighbor_list
    
for k, v in device_dict.items():
    print('D_{}: {}'.format(k, v.neighbor_list))

D_0: [1, 2, 3, 4]
D_1: [0, 2, 3, 4]
D_2: [0, 1, 3, 4]
D_3: [0, 1, 2, 4]
D_4: [0, 1, 2, 3]


In [29]:
# overall performance
write_log(log_path, time.ctime(time.time())+' SMLLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
for rho in range(0,11):
    metric = []
    for k, v in device_dict.items():
        metric.append(v.validate_ensamble(test_loader=loader_dict[k][1], device_dict=device_dict, rho=rho/10))
    metric_arr=np.array(metric)
    log_txt = 'rho = {:.2f}   Avg_acc: {:.4f}'.format(rho/10, np.mean(metric_arr, axis=0)[0])
    print(log_txt)
    write_log(log_path, '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(np.mean(metric_arr, axis=0)[0], 
                                                                np.mean(metric_arr, axis=0)[1],
                                                                np.mean(metric_arr, axis=0)[2],
                                                                np.mean(metric_arr, axis=0)[3]))

Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1749.0/1770 (0.9881)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1779.0/1799 (0.9889)
Device:  4 Val_ensamble - Acc: 1790.0/1810 (0.9890)
rho = 0.00   Avg_acc: 0.9850
Device:  0 Val_ensamble - Acc: 1748.0/1800 (0.9711)
Device:  1 Val_ensamble - Acc: 1750.0/1770 (0.9887)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1777.0/1799 (0.9878)
Device:  4 Val_ensamble - Acc: 1795.0/1810 (0.9917)
rho = 0.10   Avg_acc: 0.9857
Device:  0 Val_ensamble - Acc: 1750.0/1800 (0.9722)
Device:  1 Val_ensamble - Acc: 1750.0/1770 (0.9887)
Device:  2 Val_ensamble - Acc: 1781.0/1797 (0.9911)
Device:  3 Val_ensamble - Acc: 1778.0/1799 (0.9883)
Device:  4 Val_ensamble - Acc: 1794.0/1810 (0.9912)
rho = 0.20   Avg_acc: 0.9863
Device:  0 Val_ensamble - Acc: 1757.0/1800 (0.9761)
Device:  1 Val_ensamble - Acc: 1753.0/1770 (0.9904)
Device:  2 Val_ensamble - Acc

In [30]:
# assign neighbors
for k, v in device_dict.items():
    # full neighbor list
    neighbor_list = [x for x in device_dict.keys() if x != k]
    neighbor_list = neighbor_list[:-1]
    if k == 4:
        neighbor_list = []
    v.neighbor_list = neighbor_list
    
for k, v in device_dict.items():
    print('D_{}: {}'.format(k, v.neighbor_list))

D_0: [1, 2, 3]
D_1: [0, 2, 3]
D_2: [0, 1, 3]
D_3: [0, 1, 2]
D_4: []


In [31]:
# overall performance
write_log(log_path, time.ctime(time.time())+' SMLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
for rho in range(0,11):
    metric = []
    for k, v in device_dict.items():
        if k != 4:
            metric.append(v.validate_ensamble(test_loader=loader_dict[k][1], device_dict=device_dict, rho=rho/10))
    metric_arr=np.array(metric)
    log_txt = 'rho = {:.2f}   Avg_acc: {:.4f}'.format(rho/10, np.mean(metric_arr, axis=0)[0])
    print(log_txt)
    write_log(log_path, '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(np.mean(metric_arr, axis=0)[0], 
                                                                np.mean(metric_arr, axis=0)[1],
                                                                np.mean(metric_arr, axis=0)[2],
                                                                np.mean(metric_arr, axis=0)[3]))

Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1749.0/1770 (0.9881)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1779.0/1799 (0.9889)
rho = 0.00   Avg_acc: 0.9840
Device:  0 Val_ensamble - Acc: 1744.0/1800 (0.9689)
Device:  1 Val_ensamble - Acc: 1751.0/1770 (0.9893)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1778.0/1799 (0.9883)
rho = 0.10   Avg_acc: 0.9840
Device:  0 Val_ensamble - Acc: 1749.0/1800 (0.9717)
Device:  1 Val_ensamble - Acc: 1750.0/1770 (0.9887)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1777.0/1799 (0.9878)
rho = 0.20   Avg_acc: 0.9844
Device:  0 Val_ensamble - Acc: 1751.0/1800 (0.9728)
Device:  1 Val_ensamble - Acc: 1749.0/1770 (0.9881)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  3 Val_ensamble - Acc: 1777.0/1799 (0.9878)
rho = 0.30   Avg_acc: 0.9845
Device:  0 Val_ensamble - Acc: 1752.0/1800 (0.9733)


In [32]:
# assign neighbors
for k, v in device_dict.items():
    # full neighbor list
    neighbor_list = [x for x in device_dict.keys() if x != k]
    neighbor_list = [x for x in neighbor_list if x not in [3,4]]
    if k == 4:
        neighbor_list = []
    if k == 3:
        neighbor_list = []
    v.neighbor_list = neighbor_list
    
for k, v in device_dict.items():
    print('D_{}: {}'.format(k, v.neighbor_list))

D_0: [1, 2]
D_1: [0, 2]
D_2: [0, 1]
D_3: []
D_4: []


In [33]:
# overall performance
write_log(log_path, time.ctime(time.time())+' SML')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
for rho in range(0,11):
    metric = []
    for k, v in device_dict.items():
        if k not in [3,4]:
            metric.append(v.validate_ensamble(test_loader=loader_dict[k][1], device_dict=device_dict, rho=rho/10))
    metric_arr=np.array(metric)
    log_txt = 'rho = {:.2f}   Avg_acc: {:.4f}'.format(rho/10, np.mean(metric_arr, axis=0)[0])
    print(log_txt)
    write_log(log_path, '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(np.mean(metric_arr, axis=0)[0], 
                                                                np.mean(metric_arr, axis=0)[1],
                                                                np.mean(metric_arr, axis=0)[2],
                                                                np.mean(metric_arr, axis=0)[3]))

Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1749.0/1770 (0.9881)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
rho = 0.00   Avg_acc: 0.9823
Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1748.0/1770 (0.9876)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
rho = 0.10   Avg_acc: 0.9821
Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1746.0/1770 (0.9864)
Device:  2 Val_ensamble - Acc: 1780.0/1797 (0.9905)
rho = 0.20   Avg_acc: 0.9821
Device:  0 Val_ensamble - Acc: 1748.0/1800 (0.9711)
Device:  1 Val_ensamble - Acc: 1743.0/1770 (0.9847)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
rho = 0.30   Avg_acc: 0.9818
Device:  0 Val_ensamble - Acc: 1740.0/1800 (0.9667)
Device:  1 Val_ensamble - Acc: 1742.0/1770 (0.9842)
Device:  2 Val_ensamble - Acc: 1779.0/1797 (0.9900)
rho = 0.40   Avg_acc: 0.9803
Device:  0 Val_ensamble - Acc: 1735.0/1800 (0.9639)
Device:  1 Val_ensamble

In [34]:
# assign neighbors
for k, v in device_dict.items():
    # full neighbor list
    neighbor_list = [x for x in device_dict.keys() if x != k]
    neighbor_list = [x for x in neighbor_list if x not in [3]]
    if k == 3:
        neighbor_list = []
    v.neighbor_list = neighbor_list
    
for k, v in device_dict.items():
    print('D_{}: {}'.format(k, v.neighbor_list))

D_0: [1, 2, 4]
D_1: [0, 2, 4]
D_2: [0, 1, 4]
D_3: []
D_4: [0, 1, 2]


In [35]:
# overall performance
write_log(log_path, time.ctime(time.time())+' SMLL')
write_log(log_path, 'AvgAcc\tAvgPre\tAvgRec\tAvgf1')
for rho in range(0,11):
    metric = []
    for k, v in device_dict.items():
        if k not in [3]:
            metric.append(v.validate_ensamble(test_loader=loader_dict[k][1], device_dict=device_dict, rho=rho/10))
    metric_arr=np.array(metric)
    log_txt = 'rho = {:.2f}   Avg_acc: {:.4f}'.format(rho/10, np.mean(metric_arr, axis=0)[0])
    print(log_txt)
    write_log(log_path, '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(np.mean(metric_arr, axis=0)[0], 
                                                                np.mean(metric_arr, axis=0)[1],
                                                                np.mean(metric_arr, axis=0)[2],
                                                                np.mean(metric_arr, axis=0)[3]))

Device:  0 Val_ensamble - Acc: 1745.0/1800 (0.9694)
Device:  1 Val_ensamble - Acc: 1749.0/1770 (0.9881)
Device:  2 Val_ensamble - Acc: 1778.0/1797 (0.9894)
Device:  4 Val_ensamble - Acc: 1790.0/1810 (0.9890)
rho = 0.00   Avg_acc: 0.9840
Device:  0 Val_ensamble - Acc: 1750.0/1800 (0.9722)
Device:  1 Val_ensamble - Acc: 1750.0/1770 (0.9887)
Device:  2 Val_ensamble - Acc: 1779.0/1797 (0.9900)
Device:  4 Val_ensamble - Acc: 1794.0/1810 (0.9912)
rho = 0.10   Avg_acc: 0.9855
Device:  0 Val_ensamble - Acc: 1755.0/1800 (0.9750)
Device:  1 Val_ensamble - Acc: 1746.0/1770 (0.9864)
Device:  2 Val_ensamble - Acc: 1782.0/1797 (0.9917)
Device:  4 Val_ensamble - Acc: 1793.0/1810 (0.9906)
rho = 0.20   Avg_acc: 0.9859
Device:  0 Val_ensamble - Acc: 1756.0/1800 (0.9756)
Device:  1 Val_ensamble - Acc: 1748.0/1770 (0.9876)
Device:  2 Val_ensamble - Acc: 1782.0/1797 (0.9917)
Device:  4 Val_ensamble - Acc: 1791.0/1810 (0.9895)
rho = 0.30   Avg_acc: 0.9861
Device:  0 Val_ensamble - Acc: 1756.0/1800 (0.9756)
