In the distributed setting, we study weight update algorithms (such as DSGD, ADMM, and Collaborative Learning). This code gives a brief overview of approaches to this problem.
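All of the methods below can be read as solvers for the standard decentralized consensus problem (textbook background, not something specific to this code): each node i holds a local loss f_i, and the nodes must agree on shared weights w. In LaTeX notation:

\min_{w} \sum_{i=1}^{n} f_i(w)
\quad\Longleftrightarrow\quad
\min_{w_1,\ldots,w_n} \sum_{i=1}^{n} f_i(w_i)
\quad \text{s.t.} \quad w_i = w_j \quad \forall (i,j) \in \mathcal{E}

where \mathcal{E} is the set of communication edges between nodes. DSGD and Gossip SGD enforce the constraint by averaging parameters over edges, while the ADMM/PDMM variants enforce it through dual variables.

The commands below install the dependencies and launch three DSGD nodes (node0, node1, node2), each pinned to its own CUDA device and logging to its own output directory.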
pip install -r requirements.txt
python main.py ./data ./output/ring1 --epochs 200 --batch_size 100 --l2_lambda 0.01 --model_name resnet50 --dataset_name cifar10 --group_channels 32 --drop_rate 0.1 --last_drop_rate 0.5 --loglevel INFO --train_data_length 10000 --cuda_device_no 1 --skip_plots --scheduler StepLR --step_size 410 --gamma 0.9 --sleep_factor 0 DSGD node0 ./conf/node_list.json ./conf/hosts.json --lr 0.002 --swap_timeout 10 --async_step > ./output/ring1/node0.txt
python main.py ./data ./output/ring2 --epochs 200 --batch_size 100 --l2_lambda 0.01 --model_name resnet50 --dataset_name cifar10 --group_channels 32 --drop_rate 0.1 --last_drop_rate 0.5 --loglevel INFO --train_data_length 10000 --cuda_device_no 2 --skip_plots --scheduler StepLR --step_size 410 --gamma 0.9 --sleep_factor 0 DSGD node1 ./conf/node_list.json ./conf/hosts.json --lr 0.002 --swap_timeout 10 --async_step > ./output/ring2/node1.txt
python main.py ./data ./output/ring3 --epochs 200 --batch_size 100 --l2_lambda 0.01 --model_name resnet50 --dataset_name cifar10 --group_channels 32 --drop_rate 0.1 --last_drop_rate 0.5 --loglevel INFO --train_data_length 10000 --cuda_device_no 3 --skip_plots --scheduler StepLR --step_size 410 --gamma 0.9 --sleep_factor 0 DSGD node2 ./conf/node_list.json ./conf/hosts.json --lr 0.002 --swap_timeout 10 --async_step > ./output/ring3/node2.txt
optim_args = {}
if args.optimizer == "sgd":
    optim_args["lr"] = args.lr
    optim_args["momentum"] = args.momentum
    optim_args["weight_decay"] = args.weight_decay
    optim_args["dampening"] = args.dampening
    optim_args["nesterov"] = args.nesterov
elif args.optimizer == "adam":
    optim_args["lr"] = args.lr
    optim_args["betas"] = args.betas
    optim_args["eps"] = args.eps
    optim_args["weight_decay"] = args.weight_decay
    optim_args["amsgrad"] = args.amsgrad
elif args.optimizer == "DSGD":
    optim_args["lr"] = args.lr
    optim_args["momentum"] = args.momentum
    optim_args["weight_decay"] = args.weight_decay
    optim_args["dampening"] = args.dampening
    optim_args["nesterov"] = args.nesterov
    optim_args["weight"] = args.weight
    optim_args["round_step"] = args.round_step
    optim_args["swap_timeout"] = args.swap_timeout
elif args.optimizer == "GossipSGD":
    optim_args["lr"] = args.lr
    optim_args["momentum"] = args.momentum
    optim_args["weight_decay"] = args.weight_decay
    optim_args["dampening"] = args.dampening
    optim_args["nesterov"] = args.nesterov
    optim_args["weight"] = args.weight
    optim_args["round_step"] = args.round_step
    optim_args["swap_timeout"] = args.swap_timeout
elif args.optimizer == "AdmmSGD":
    optim_args["lr"] = args.lr
    optim_args["mu"] = args.mu
    optim_args["eta"] = args.eta
    optim_args["rho"] = args.rho
    optim_args["drs"] = True
    optim_args["round_step"] = args.round_step
    optim_args["swap_timeout"] = args.swap_timeout
elif args.optimizer == "PdmmSGD":
    optim_args["lr"] = args.lr
    optim_args["mu"] = args.mu
    optim_args["eta"] = args.eta
    optim_args["rho"] = args.rho
    optim_args["drs"] = False
    optim_args["round_step"] = args.round_step
    optim_args["swap_timeout"] = args.swap_timeout
elif args.optimizer == "PdmmISVR":
    optim_args["lr"] = args.lr
    optim_args["round_step"] = args.round_step
    optim_args["use_gcoef"] = args.use_gcoef
    optim_args["drs"] = False
    optim_args["piw"] = args.piw
    optim_args["swap_timeout"] = args.swap_timeout
elif args.optimizer == "AdmmISVR":
    optim_args["lr"] = args.lr
    optim_args["round_step"] = args.round_step
    optim_args["use_gcoef"] = args.use_gcoef
    optim_args["drs"] = True
    optim_args["piw"] = args.piw
    optim_args["swap_timeout"] = args.swap_timeout
Ex.) Basic optimizer code
import torch.nn as nn
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for t in range(epochs):
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)            # forward pass
        loss = loss_fn(pred, y)
        optimizer.zero_grad()      # clear gradients from the previous step
        loss.backward()            # backpropagate
        optimizer.step()           # apply the weight update
Ex.) Stochastic Gradient Descent (SGD)
# optimizer.step()
# W = W - lr * dl/dW
for group in param_groups:
    for i, param in enumerate(group['params']):
        param.add_(param.grad, alpha=-lr)
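For a self-contained version of the same update (a minimal sketch, not this repository's implementation), the rule can be packaged as a torch.optim.Optimizer subclass:

import torch
from torch.optim import Optimizer

class PlainSGD(Optimizer):
    # Vanilla SGD: W = W - lr * dl/dW (no momentum, no weight decay).
    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.add_(param.grad, alpha=-group["lr"])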
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.616, 2.192, 2.191, 0.173, 0.178, 0.175
...
200, 0.311, 0.187, 0.998, 0.942, 0.939, [ 0.721 ]
Ex.) Decentralized Stochastic Gradient Descent (D-SGD)
for group in param_groups:
    for i, param in enumerate(group['params']):
        param.add_(param.grad, alpha=-lr)
        ############################################
        # Average the local parameter with the copies
        # received from every neighboring edge.
        param_sum = torch.zeros_like(param)
        param_sum += param.data
        for edge in edges:
            param_sum += edge.prm_state["rcv"][i]
        param.data = param_sum / edge_num  # edge_num: number of averaged copies (self + neighbors)
        # Stage the averaged parameter to be sent to the neighbors.
        for edge in edges:
            edge.prm_state["snd"][i] = param.data
        ###########################################
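The communication step amounts to the consensus update w_i <- (w_i + sum over neighbors of w_j) / (|N_i| + 1). A tiny standalone simulation of just this averaging (plain tensors instead of the edge objects above; the four-node ring is an illustrative assumption):

import torch

params = [torch.tensor([0.0]), torch.tensor([2.0]),
          torch.tensor([4.0]), torch.tensor([6.0])]
ring = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}  # 4-node ring

for step in range(20):
    new_params = []
    for i, neighbors in ring.items():
        total = params[i].clone()
        for j in neighbors:
            total += params[j]          # add each neighbor's copy
        new_params.append(total / (len(neighbors) + 1))
    params = new_params                 # synchronous update, as in D-SGD

print(params)  # every node approaches the global mean, 3.0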
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.560, 2.417, 2.415, 0.142, 0.111, 0.118
...
200, 1.761, 1.617, 1.630, 0.465, 0.424, [ 0.416 ]
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.580, 2.349, 2.349, 0.186, 0.148, 0.145
...
200, 1.654, 1.624, 1.632, 0.462, 0.426, [ 0.421 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.555, 2.389, 2.390, 0.186, 0.121, 0.119
...
100, 1.632, 1.687, 1.701, 0.490, 0.406, [ 0.405 ]
...
200,
Ex.) Gossip Stochastic Gradient Descent (Gossip SGD)
for edge in edges:
    for group in param_groups:
        for i, param in enumerate(group['params']):
            ############################################
            # Pairwise-average with the copy received over this edge,
            # then stage the result to be sent back.
            param.data = torch.div((param.data + edge.prm_state["rcv"][i]), 2)
            edge.prm_state["snd"][i] = param.data
            ###########################################
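Each exchange replaces both endpoints with their midpoint, w_i, w_j <- (w_i + w_j) / 2, so the global mean is preserved while the nodes drift together. A toy illustration (again with plain tensors rather than the repository's edge objects):

import random
import torch

params = [torch.tensor([0.0]), torch.tensor([3.0]), torch.tensor([9.0])]
edges = [(0, 1), (1, 2), (0, 2)]  # three nodes, ring topology

random.seed(0)
for step in range(50):
    i, j = random.choice(edges)         # one gossip exchange per step
    mid = (params[i] + params[j]) / 2   # both endpoints adopt the midpoint
    params[i], params[j] = mid.clone(), mid.clone()

print(params)  # all values cluster around the global mean, 4.0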
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.574, 2.467, 2.467, 0.192, 0.120, 0.121
...
198, 1.489, 1.620, 1.637, 0.480, 0.435, [ 0.433 ]
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.643, 2.443, 2.444, 0.153, 0.107, 0.107
...
200, 1.423, 1.445, 1.463, 0.536, 0.489, [ 0.486 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.546, 2.392, 2.392, 0.252, 0.160, 0.154
...
89, 1.438, 1.471, 1.500, 0.534, 0.485, [ 0.477 ]
...
200
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.770, 2.907, 2.906, 0.099, 0.100, 0.100
...
200,
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.655, 2.931, 2.931, 0.123, 0.121, 0.122
...
200, 0.842, 0.815, 0.982, 0.707, 0.718, [ 0.662 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.860, 2.697, 2.697, 0.156, 0.100, 0.100
...
85, 1.560, 3.381, 3.381, 0.178, 0.129, [ 0.12 ]
...
200
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.315, 2.366, 2.364, 0.169, 0.118, 0.120
...
199, 0.992, 0.866, 1.050, 0.673, 0.701, [ 0.651 ]
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.311, 2.336, 2.338, 0.147, 0.105, 0.103
...
200, 0.886, 0.790, 0.991, 0.715, 0.726, [ 0.667 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.290, 2.493, 2.494, 0.227, 0.144, 0.145
...
84, 0.857, 0.833, 1.026, 0.741, 0.713, [ 0.655 ]
...
200
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.512, 2.436, 2.431, 0.180, 0.118, 0.119
...
200, 1.655, 1.246, 1.300, 0.585, 0.580, [ 0.562 ]
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.419, 2.319, 2.320, 0.210, 0.157, 0.155
...
200, 0.891, 0.951, 1.125, 0.688, 0.671, [ 0.623 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.463, 2.510, 2.511, 0.178, 0.113, 0.115
...
85, 1.776, 1.225, 1.278, 0.641, 0.562, [ 0.547 ]
...
200
(Node-1)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.590, 2.459, 2.456, 0.166, 0.114, 0.117
...
200, 1.431, 4.144, 4.163, 0.257, 0.201, [ 0.198 ]
(Node-2)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.486, 2.330, 2.330, 0.199, 0.145, 0.145
...
200, 0.865, 0.805, 1.026, 0.715, 0.722, [ 0.661 ]
(Node-3)
epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
1, 2.446, 2.492, 2.494, 0.217, 0.138, 0.13
...
86, 1.110, 1.457, 1.547, 0.680, 0.539, [ 0.518 ]
...
200