#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import time
import pickle

import numpy as np
import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, update_model_inplace, test_inference
from utils import get_model, get_dataset, average_weights, exp_details, average_parameter_delta
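
# Server-side driver for federated training. Each global round samples a
# fraction of the users, runs local training on their data shards via
# LocalUpdate, averages the returned state dicts (average_weights, so
# batch-norm buffers are aggregated too) and the trainable-parameter deltas
# (average_parameter_delta), then lets update_model_inplace apply the
# averaged delta through the server optimizer selected by args.optimizer.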
if __name__ == '__main__':
    start_time = time.time()

    args = args_parser()
    exp_details(args)

    # Run identifier built from the key hyperparameters; used both as the
    # TensorBoard run name and as the output pickle filename.
    file_name = '/{}_{}_{}_llr[{}]_glr[{}]_eps[{}]_le[{}]_bs[{}]_iid[{}]_mi[{}]_frac[{}].pkl'.\
        format(args.dataset, args.model, args.optimizer,
               args.local_lr, args.lr, args.eps,
               args.local_ep, args.local_bs, args.iid, args.max_init, args.frac)
    logger = SummaryWriter('./logs/' + file_name)

    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    torch.set_num_threads(1)  # limit CPU usage
    print('-- pytorch version:', torch.__version__)

    # Seed every RNG in play so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == 'cuda':  # note: a torch.device never compares equal to the string 'cpu'
        torch.cuda.manual_seed(args.seed)

    os.makedirs(args.outfolder, exist_ok=True)
    # Load the dataset and the per-user partition of sample indices.
    train_dataset, test_dataset, num_classes, user_groups = get_dataset(args)

    # Build the global model, move it to the device, and put it in train mode.
    global_model = get_model(args.model, args.dataset, train_dataset[0][0].shape, num_classes)
    global_model.to(device)
    global_model.train()

    # Per-parameter state buffers for the server-side optimizer.
    momentum_buffer_list = []
    exp_avgs = []
    exp_avg_sqs = []
    max_exp_avg_sqs = []
    for p in global_model.parameters():
        momentum_buffer_list.append(torch.zeros_like(p, dtype=torch.float))
        exp_avgs.append(torch.zeros_like(p, dtype=torch.float))
        exp_avg_sqs.append(torch.zeros_like(p, dtype=torch.float))
        max_exp_avg_sqs.append(torch.full_like(p, args.max_init, dtype=torch.float))  # e.g. 1e-2
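
    # The buffer names mirror PyTorch's own optimizer state: momentum_buffer
    # for SGD with momentum, exp_avg / exp_avg_sq for Adam's first and second
    # moments, and max_exp_avg_sq for an AMSGrad-style running maximum of the
    # second moment, here warm-started at args.max_init instead of zero.
    # Which buffers are actually used depends on args.optimizer and is decided
    # inside update_model_inplace (update.py).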
    # Training
    train_loss = []
    test_loss, test_accuracy = [], []

    for epoch in tqdm(range(args.epochs)):
        ep_time = time.time()
        local_weights, local_params, local_losses = [], [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # Snapshot the trainable parameters before this round's update; the
        # averaged client delta below is taken relative to this snapshot.
        par_before = [p.data.detach().clone() for p in global_model.parameters()]

        global_model.train()

        # Sample a fraction args.frac of the users for this round (at least one).
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
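
        # The client sample is drawn without replacement from the NumPy RNG
        # seeded above, so the participation schedule is reproducible for a
        # fixed args.seed.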
        # Local training on each sampled user.
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, params, loss = local_model.update_weights_local(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_params.append(copy.deepcopy(params))
            local_losses.append(copy.deepcopy(loss))

        # Average the full state dicts so that non-trainable buffers (e.g.
        # batch-norm running statistics) are aggregated as well, and load the
        # result into the global model.
        bn_weights = average_weights(local_weights)
        global_model.load_state_dict(bn_weights)
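
        # average_weights (utils.py) is presumably the standard FedAvg mean
        # over state dicts, along the lines of (a sketch under that
        # assumption, not the actual implementation):
        #
        #     avg = copy.deepcopy(local_weights[0])
        #     for key in avg:
        #         for w_k in local_weights[1:]:
        #             avg[key] += w_k[key]
        #         avg[key] = torch.div(avg[key], len(local_weights))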
        # Update the trainable parameters through the chosen server optimizer:
        # average the per-client parameter deltas (any compression is applied
        # inside average_parameter_delta), then apply the result in place.
        global_delta = average_parameter_delta(local_params, par_before)
        update_model_inplace(
            global_model, par_before, global_delta, args, epoch,
            momentum_buffer_list, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
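
        # This is the FedOpt pattern: the averaged client delta acts as a
        # pseudo-gradient for the server optimizer. For an Adam-like server
        # step the update is roughly (a sketch; the exact rule lives in
        # update.py's update_model_inplace):
        #
        #     m_t   = beta1 * m_{t-1} + (1 - beta1) * delta_t
        #     v_t   = beta2 * v_{t-1} + (1 - beta2) * delta_t ** 2
        #     v_hat = max(v_hat, v_t)     # AMSGrad variant, cf. max_exp_avg_sqs
        #     x_t   = x_{t-1} + args.lr * m_t / (sqrt(v_hat) + args.eps)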
        # Report and store the training loss, averaged over the sampled users.
        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        print('Epoch Run Time: {0:0.4f} of {1} global rounds'.format(time.time() - ep_time, epoch + 1))
        print(f'Training Loss : {train_loss[-1]}')
        logger.add_scalar('train loss', train_loss[-1], epoch)

        # Switch to eval mode (fixes dropout and batch-norm behavior) for testing.
        global_model.eval()
        # Evaluate the global model on the held-out test set after every round.
        test_acc, test_ls = test_inference(args, global_model, test_dataset)
        test_accuracy.append(test_acc)
        test_loss.append(test_ls)

        print(f'Test Loss : {test_loss[-1]}')
        print(f'Test Accuracy : {test_accuracy[-1]} \n')
        logger.add_scalar('test loss', test_loss[-1], epoch)
        logger.add_scalar('test acc', test_accuracy[-1], epoch)
        if args.save:
            # Checkpoint the loss/accuracy curves collected so far.
            with open(args.outfolder + file_name, 'wb') as f:
                pickle.dump([train_loss, test_loss, test_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
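
# Example invocation (assuming options.py exposes the attributes used above
# as argparse flags of the same names; check options.py for the real flag
# names and defaults):
#
#   python federated_main.py --dataset cifar --model cnn --epochs 100 \
#       --num_users 100 --frac 0.1 --local_ep 5 --local_bs 50 \
#       --lr 1.0 --local_lr 0.01 --seed 1 --gpu 0 --save 1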