In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import random
import time
import sys

sys.path.append("..")
from utils.tool import GPU_info, write_log, print_label_stat
from models.resnet import ResNet8, ResNet18, ResNet50
from models.mobilenet import MobileNet_S, MobileNet_M, MobileNet_L
from framework.ours import Device
from framework.DSGD import DSGD_Device
from framework.SISA import SISA_Device
from framework.FedAvgUnl import FedAvgServer, FedAvgClient


# Experiment configuration. Later cells branch on these values.

# one in ['ours', 'DSGD', 'SISA', 'Fed']
# NOTE(review): device construction in this notebook always uses DSGD_Device;
# Framework only labels the log file path, so other choices are not wired up.
Framework = 'DSGD'

# one in ['MNIST', 'FMNIST', 'CIFAR10']
DatasetName = 'CIFAR10'

# one in ['resnet', 'mobilenet']
ModelType = 'mobilenet'

# one in [1,0]. 1 - Heterogeneous;  0 - Homogeneous
Heterogeneous = 1

# Fail fast on typos before any data is downloaded or models are built.
assert Framework in ('ours', 'DSGD', 'SISA', 'Fed'), Framework
assert DatasetName in ('MNIST', 'FMNIST', 'CIFAR10'), DatasetName
assert ModelType in ('resnet', 'mobilenet'), ModelType
assert Heterogeneous in (0, 1), Heterogeneous

In [2]:
# Run-wide constants, dataset loading and the shared reference loader.
train_set_o = None
test_set_o = None
num_classes = 10
device_num = 5
num_channel = 1                  # set to 3 below for CIFAR10
ref_size = 10000                 # leading test-split samples used as the reference set
# NOTE(review): 50000 matches the CIFAR10 train split; MNIST/FMNIST have 60000
# samples, so with those only the first 50000 are partitioned across devices.
train_test_total_size = int(50000/device_num)
test_ratio = 0.2

train_batch_size = 256 if ModelType == 'resnet' else 32
save_path = './checkpoint'
data_path = '../data'
log_path = '../log/unlearn_{}_{}_{}.txt'.format(Framework, ModelType, DatasetName)

# Seed every RNG in play. DataLoader shuffling and model construction draw from
# torch's CPU generator, which torch.cuda.manual_seed alone leaves unseeded.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)       # seeds the CPU and all CUDA generators
torch.cuda.manual_seed(my_seed)

dataset_classes = {
    'CIFAR10': datasets.CIFAR10,
    'MNIST': datasets.MNIST,
    'FMNIST': datasets.FashionMNIST,
}
if DatasetName not in dataset_classes:
    # Otherwise train_set_o stays None and the cell fails later with an opaque error.
    raise ValueError('Unknown DatasetName: {}'.format(DatasetName))
train_set_o = dataset_classes[DatasetName](data_path, train=True, download=True)
test_set_o = dataset_classes[DatasetName](data_path, train=False, download=True)
if DatasetName == 'CIFAR10':
    num_channel = 3

# ToTensor scales to [0, 1]; Normalize(0.5, 0.5) then maps to [-1, 1]. The
# single-value mean/std broadcasts over however many channels the image has.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set_o.transform = normalize
test_set_o.transform = normalize

ref_set = Subset(test_set_o, range(int(ref_size)))
ref_loader = DataLoader(ref_set, batch_size=train_batch_size,
                        shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# data manipulation
# Partition the training split across devices. Each device takes a contiguous
# slice of train_set_o, drops one class from it (non-IID), and splits the rest
# into local train/test subsets.
device_dict = {}
loader_dict = {}
print('\t'.join(['Device'] + ['Class' + str(c) for c in range(num_classes)] + ['SUM']))
for device_id in range(device_num):
    range_start = train_test_total_size * device_id
    range_end = range_start + train_test_total_size
    # remove one class from each local dataset; wrap by num_classes (not a
    # hard-coded 10) so the mapping stays valid for other class counts
    class_to_remove = device_id % num_classes
    # as_tensor accepts both list targets (CIFAR10) and tensor targets
    # (MNIST/FMNIST) without the copy-construct warning torch.tensor emits
    local_targets = torch.as_tensor(train_set_o.targets[range_start:range_end])
    indices = (local_targets != class_to_remove).nonzero(as_tuple=True)[0]
    # split train&test; indices are slice-relative, so shift back by range_start
    train_test_border = int((1-test_ratio)*len(indices))
    train_set = Subset(train_set_o, indices[:train_test_border]+range_start)
    test_set = Subset(train_set_o, indices[train_test_border:]+range_start)
    train_loader = DataLoader(train_set, batch_size=train_batch_size,
                              shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=train_batch_size,
                             shuffle=True, num_workers=0)
    loader_dict[device_id] = [train_loader, test_loader]
    print_label_stat(device_id, train_set, num_classes)


# initialize devices
# NOTE(review): DSGD_Device is built regardless of Framework; the other
# framework classes imported above are not wired up in this notebook.
gpu_id = 0
for device_id in range(device_num):
    device_dict[device_id] = DSGD_Device(device_id, gpu_id, num_classes, num_channel)

Device	Class0	Class1	Class2	Class3	Class4	Class5	Class6	Class7	Class8	Class9	SUM
D-0	0	764	817	807	806	765	830	809	807	791	7196
D-1	793	0	817	806	802	768	813	824	782	789	7194
D-2	798	807	0	817	790	815	787	806	781	827	7228
D-3	790	770	852	0	809	824	803	781	802	788	7219
D-4	822	816	770	806	0	831	778	778	788	813	7202


In [4]:
# Give every device a fresh model (moved to GPU 0) and an Adam optimiser.
# The scenario only decides which constructor is used and what gets logged.
# NOTE(review): in both scenarios every device gets the same architecture
# (small vs large variant); ResNet18 / MobileNet_M are imported but unused —
# confirm this is the intended meaning of "heterogeneous".
if Heterogeneous == 1:
    # heterogeneous scenario
    scenario = 'heterogeneous'
    build_model = ResNet8 if ModelType == 'resnet' else MobileNet_S
else:
    # homogeneous scenario
    scenario = 'homogeneous'
    build_model = ResNet50 if ModelType == 'resnet' else MobileNet_L
write_log(log_path, scenario)

gpu_id = 0
for device_id in range(device_num):
    node = device_dict[device_id]
    node.model = build_model(num_channel)
    node.model.cuda(gpu_id)
    node.optimizer = optim.Adam(node.model.parameters(), lr=0.01)

In [5]:
# Sanity check: list the device ids registered so far (expected 0..device_num-1).
print(device_dict.keys())

dict_keys([0, 1, 2, 3, 4])


In [6]:
# remove one client for rejoin
# Device 4 is held out of the initial federation and kept in device_back_up
# so it can rejoin after unlearning.
print(device_dict.keys())
device_back_up = device_dict.pop(4)
print(device_dict.keys())

# assign neighbors: fully connected topology over the remaining devices
for node_id, node in device_dict.items():
    node.neighbor_list = [peer for peer in device_dict if peer != node_id]

dict_keys([0, 1, 2, 3, 4])
dict_keys([0, 1, 2, 3])


In [None]:
# train models
# Each epoch, every device trains on its own train split and then runs a
# communication round with its neighbours. Afterwards the device-averaged
# validation metrics are printed and appended to the log.
epoch_num = 800 if ModelType == 'resnet' else 200
write_log(log_path, time.ctime(time.time()))
for epoch in range(epoch_num):
    for device_id, device in device_dict.items():
        local_train_loader = loader_dict[device_id][0]
        device.train_local_model(num_iter=1, local_loader=local_train_loader)
        device.communication(frac=1, gamma=0.9, device_dict=device_dict)

    metric = []
    for device_id, device in device_dict.items():
        local_test_loader = loader_dict[device_id][1]
        metric.append(device.validate_local_model(test_loader=local_test_loader))
    metric_arr = np.array(metric)
    mean_metrics = metric_arr.mean(axis=0)   # one mean per metric column
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(*mean_metrics[:4])
    print('Epoch: {}  mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)

Client 0	Epoch:  0		Loss: 1.74429679
Client 1	Epoch:  0		Loss: 1.46573496
Client 2	Epoch:  0		Loss: 1.78396702
Client 3	Epoch:  0		Loss: 1.84331954
Server - Avg_loss: 0.0773, Acc: 209.0/1799 (0.1162)
Server - Avg_loss: 0.0722, Acc: 193.0/1799 (0.1073)
Server - Avg_loss: 0.0752, Acc: 200.0/1807 (0.1107)
Server - Avg_loss: 0.0707, Acc: 200.0/1805 (0.1108)
Epoch: 0  mean: 0.1112	0.0127	0.1108	0.0223

Client 0	Epoch:  0		Loss: 1.71681404
Client 1	Epoch:  0		Loss: 1.47704315
Client 2	Epoch:  0		Loss: 1.80626774
Client 3	Epoch:  0		Loss: 1.49480391
Server - Avg_loss: 0.0789, Acc: 223.0/1799 (0.1240)
Server - Avg_loss: 0.0738, Acc: 193.0/1799 (0.1073)
Server - Avg_loss: 0.0820, Acc: 200.0/1807 (0.1107)
Server - Avg_loss: 0.0723, Acc: 197.0/1805 (0.1091)
Epoch: 1  mean: 0.1128	0.0184	0.1135	0.0256

Client 0	Epoch:  0		Loss: 1.55770123
Client 1	Epoch:  0		Loss: 1.41021097
Client 2	Epoch:  0		Loss: 2.10281253
Client 3	Epoch:  0		Loss: 1.81956327
Server - Avg_loss: 0.0826, Acc: 219.0/1799 (0.1217

Client 3	Epoch:  0		Loss: 1.08443010
Server - Avg_loss: 0.0810, Acc: 492.0/1799 (0.2735)
Server - Avg_loss: 0.0936, Acc: 278.0/1799 (0.1545)
Server - Avg_loss: 0.1018, Acc: 308.0/1807 (0.1704)
Server - Avg_loss: 0.0989, Acc: 381.0/1805 (0.2111)
Epoch: 24  mean: 0.2024	0.1428	0.2017	0.1232

Client 0	Epoch:  0		Loss: 1.18121791
Client 1	Epoch:  0		Loss: 1.78685582
Client 2	Epoch:  0		Loss: 0.97397178
Client 3	Epoch:  0		Loss: 1.36177278
Server - Avg_loss: 0.0769, Acc: 489.0/1799 (0.2718)
Server - Avg_loss: 0.1010, Acc: 354.0/1799 (0.1968)
Server - Avg_loss: 0.0925, Acc: 367.0/1807 (0.2031)
Server - Avg_loss: 0.0892, Acc: 441.0/1805 (0.2443)
Epoch: 25  mean: 0.2290	0.1639	0.2279	0.1458

Client 0	Epoch:  0		Loss: 0.97621667
Client 1	Epoch:  0		Loss: 1.51443303
Client 2	Epoch:  0		Loss: 1.08268499
Client 3	Epoch:  0		Loss: 0.91364431
Server - Avg_loss: 0.0723, Acc: 517.0/1799 (0.2874)
Server - Avg_loss: 0.0785, Acc: 465.0/1799 (0.2585)
Server - Avg_loss: 0.1106, Acc: 322.0/1807 (0.1782)
Ser

Server - Avg_loss: 0.0576, Acc: 754.0/1805 (0.4177)
Epoch: 44  mean: 0.4298	0.3930	0.4203	0.3452

Client 0	Epoch:  0		Loss: 1.03449845
Client 1	Epoch:  0		Loss: 1.06047177
Client 2	Epoch:  0		Loss: 0.70426273
Client 3	Epoch:  0		Loss: 0.63453656
Server - Avg_loss: 0.0415, Acc: 997.0/1799 (0.5542)
Server - Avg_loss: 0.0599, Acc: 663.0/1799 (0.3685)
Server - Avg_loss: 0.0506, Acc: 811.0/1807 (0.4488)
Server - Avg_loss: 0.0581, Acc: 737.0/1805 (0.4083)
Epoch: 45  mean: 0.4450	0.4318	0.4406	0.3729

Client 0	Epoch:  0		Loss: 1.13861096
Client 1	Epoch:  0		Loss: 1.50046670
Client 2	Epoch:  0		Loss: 1.30286229
Client 3	Epoch:  0		Loss: 1.12732995
Server - Avg_loss: 0.0455, Acc: 913.0/1799 (0.5075)
Server - Avg_loss: 0.0628, Acc: 617.0/1799 (0.3430)
Server - Avg_loss: 0.0511, Acc: 805.0/1807 (0.4455)
Server - Avg_loss: 0.0550, Acc: 764.0/1805 (0.4233)
Epoch: 46  mean: 0.4298	0.4065	0.4302	0.3569

Client 0	Epoch:  0		Loss: 0.97854775
Client 1	Epoch:  0		Loss: 1.04529095
Client 2	Epoch:  0		Loss

Client 3	Epoch:  0		Loss: 0.74479580
Server - Avg_loss: 0.0419, Acc: 972.0/1799 (0.5403)
Server - Avg_loss: 0.0416, Acc: 973.0/1799 (0.5409)
Server - Avg_loss: 0.0376, Acc: 1014.0/1807 (0.5612)
Server - Avg_loss: 0.0434, Acc: 958.0/1805 (0.5307)
Epoch: 65  mean: 0.5433	0.5171	0.5389	0.4705

Client 0	Epoch:  0		Loss: 1.01860452
Client 1	Epoch:  0		Loss: 1.12457407
Client 2	Epoch:  0		Loss: 1.09065211
Client 3	Epoch:  0		Loss: 1.12241638
Server - Avg_loss: 0.0389, Acc: 1028.0/1799 (0.5714)
Server - Avg_loss: 0.0456, Acc: 957.0/1799 (0.5320)
Server - Avg_loss: 0.0342, Acc: 1079.0/1807 (0.5971)
Server - Avg_loss: 0.0410, Acc: 1032.0/1805 (0.5717)
Epoch: 66  mean: 0.5681	0.5396	0.5615	0.4990

Client 0	Epoch:  0		Loss: 1.08043218
Client 1	Epoch:  0		Loss: 0.99306530
Client 2	Epoch:  0		Loss: 0.72701007
Client 3	Epoch:  0		Loss: 1.09537959
Server - Avg_loss: 0.0370, Acc: 1043.0/1799 (0.5798)
Server - Avg_loss: 0.0423, Acc: 954.0/1799 (0.5303)
Server - Avg_loss: 0.0364, Acc: 1061.0/1807 (0.587

Server - Avg_loss: 0.0413, Acc: 999.0/1799 (0.5553)
Server - Avg_loss: 0.0354, Acc: 1071.0/1807 (0.5927)
Server - Avg_loss: 0.0389, Acc: 1064.0/1805 (0.5895)
Epoch: 85  mean: 0.5781	0.5650	0.5757	0.5162

Client 0	Epoch:  0		Loss: 0.93935221
Client 1	Epoch:  0		Loss: 1.01419330
Client 2	Epoch:  0		Loss: 0.54320562
Client 3	Epoch:  0		Loss: 0.68101335
Server - Avg_loss: 0.0367, Acc: 1069.0/1799 (0.5942)
Server - Avg_loss: 0.0372, Acc: 1057.0/1799 (0.5875)
Server - Avg_loss: 0.0343, Acc: 1099.0/1807 (0.6082)
Server - Avg_loss: 0.0377, Acc: 1090.0/1805 (0.6039)
Epoch: 86  mean: 0.5985	0.5715	0.5942	0.5344

Client 0	Epoch:  0		Loss: 1.11567032
Client 1	Epoch:  0		Loss: 1.20903444
Client 2	Epoch:  0		Loss: 0.65330976
Client 3	Epoch:  0		Loss: 0.66376394
Server - Avg_loss: 0.0379, Acc: 1061.0/1799 (0.5898)
Server - Avg_loss: 0.0396, Acc: 1040.0/1799 (0.5781)
Server - Avg_loss: 0.0316, Acc: 1151.0/1807 (0.6370)
Server - Avg_loss: 0.0407, Acc: 1067.0/1805 (0.5911)
Epoch: 87  mean: 0.5990	0.5589

In [None]:
# unlearn one client
# Drop the client whose contribution must be forgotten.
unlearn_id = 3
write_log(log_path, '\n unlearn_id: {}'.format(unlearn_id))
del device_dict[unlearn_id]
print(device_dict.keys())

# Rebuild the fully connected topology. Otherwise the remaining devices keep
# unlearn_id in their neighbor_list while communication() is handed a
# device_dict that no longer contains it.
for k, v in device_dict.items():
    v.neighbor_list = [x for x in device_dict.keys() if x != k]

In [None]:
# reinitialize
# Estimate the wall-clock cost of one train+communicate round on a single
# device; the retrain and recovery cells size their per-epoch round counts
# from iter_time.
timing_device_id = 0
timing_rounds = 10
start_time = time.time()
for i in range(timing_rounds):
    # Use the timed device's own train loader, not a stale device_id left over
    # from an earlier loop (which pointed at the unlearned client's data).
    device_dict[timing_device_id].train_local_model(
        num_iter=3, local_loader=loader_dict[timing_device_id][0])
    device_dict[timing_device_id].communication(frac=1, gamma=0.9, device_dict=device_dict)
iter_time = (time.time() - start_time) / timing_rounds

# Fresh models (on GPU 0) and Adam optimisers for every remaining device,
# using the same scenario choice as the initial setup.
if Heterogeneous == 1:
    # heterogeneous scenario
    scenario = 'heterogeneous'
    build_model = ResNet8 if ModelType == 'resnet' else MobileNet_S
else:
    # homogeneous scenario
    scenario = 'homogeneous'
    build_model = ResNet50 if ModelType == 'resnet' else MobileNet_L
write_log(log_path, scenario)
gpu_id = 0
for device_id in device_dict.keys():
    device_dict[device_id].model = build_model(num_channel)
    device_dict[device_id].model.cuda(gpu_id)
    device_dict[device_id].optimizer = optim.Adam(device_dict[device_id].model.parameters(), lr=0.01)

# Baseline metrics of the freshly initialised models, averaged over devices.
metric = []
for device_id, device in device_dict.items():
    metric.append(device.validate_local_model(test_loader=loader_dict[device_id][1]))
metric_arr = np.array(metric)
mean_metrics = metric_arr.mean(axis=0)
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(*mean_metrics[:4])
print('mean: ' + log_txt)
write_log(log_path, log_txt)

In [None]:
# retrain models
# Retrain the remaining devices under a wall-clock budget of ~60 s per epoch.
# iter_num (rounds over all devices per epoch) starts from the iter_time
# estimate and is nudged up or down to track the cumulative budget.
epoch_num = 150 if ModelType == 'resnet' else 100
seconds_per_epoch = 60
write_log(log_path, time.ctime(time.time()))
iter_num = int(seconds_per_epoch/iter_time/len(device_dict))
start_time = time.time()
for epoch in range(epoch_num):
    for i in range(iter_num):
        for device_id, device in device_dict.items():
            device.train_local_model(num_iter=3, local_loader=loader_dict[device_id][0])
            device.communication(frac=1, gamma=0.9, device_dict=device_dict)
        print('remedy time consumption(s): {}'.format(time.time() - start_time))
    # Track the cumulative budget. Clamp at 0: a negative count trains nothing
    # either, but would need several extra epochs of increments to recover.
    if time.time() - start_time < seconds_per_epoch*(epoch+1):
        iter_num += 1
    else:
        iter_num = max(iter_num - 1, 0)

    metric = []
    for device_id, device in device_dict.items():
        metric.append(device.validate_local_model(test_loader=loader_dict[device_id][1]))
    metric_arr = np.array(metric)
    mean_metrics = metric_arr.mean(axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(*mean_metrics[:4])
    print('Epoch: {}  mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)

In [None]:
# Locally pre-train the held-out client (device 4) on its own train split
# before it rejoins the federation.
rejoin_train_loader = loader_dict[4][0]
device_back_up.train_local_model(num_iter=200, local_loader=rejoin_train_loader)

In [None]:
# Rejoin: the held-out client takes over the unlearned client's slot
# (key unlearn_id) together with device 4's data loaders.
device_dict[unlearn_id] = device_back_up
loader_dict[unlearn_id] = loader_dict[4]

# assign neighbors: rebuild the fully connected topology for the new membership
for k, v in device_dict.items():
    v.neighbor_list = [peer for peer in device_dict if peer != k]

# Device-averaged validation metrics right after the rejoin.
metric = [device.validate_local_model(test_loader=loader_dict[device_id][1])
          for device_id, device in device_dict.items()]
metric_arr = np.array(metric)
mean_metrics = metric_arr.mean(axis=0)
log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(*mean_metrics[:4])
print('Epoch: {}  mean: {}'.format(0, log_txt), end='\n\n')
write_log(log_path, log_txt)

In [None]:
# performance recovery
# After the rejoin, train all devices under a wall-clock budget of ~10 s per
# epoch (sized from iter_time) and log the device-averaged validation metrics.
epoch_num = 100                  # same for both model types
seconds_per_epoch = 10
write_log(log_path, 'Recovering')
write_log(log_path, time.ctime(time.time()))
iter_num = int(seconds_per_epoch/iter_time/len(device_dict))
start_time = time.time()
for epoch in range(epoch_num):
    for i in range(iter_num):
        for device_id, device in device_dict.items():
            device.train_local_model(num_iter=3, local_loader=loader_dict[device_id][0])
            device.communication(frac=1, gamma=0.9, device_dict=device_dict)
        print('remedy time consumption(s): {}'.format(time.time() - start_time))
    # Track the cumulative budget. Clamp at 0: a negative count trains nothing
    # either, but would need several extra epochs of increments to recover.
    if time.time() - start_time < seconds_per_epoch*(epoch+1):
        iter_num += 1
    else:
        iter_num = max(iter_num - 1, 0)

    metric = []
    for device_id, device in device_dict.items():
        metric.append(device.validate_local_model(test_loader=loader_dict[device_id][1]))
    metric_arr = np.array(metric)
    mean_metrics = metric_arr.mean(axis=0)
    log_txt = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(*mean_metrics[:4])
    print('Epoch: {}  mean: {}'.format(epoch, log_txt), end='\n\n')
    write_log(log_path, log_txt)