In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import APFL2
import data_load

# Compact tensor printing: 2 edge items per dimension, lines wrapped at 75 characters.
torch.set_printoptions(edgeitems=2, linewidth=75)

# Fix the RNG seeds so a fresh "Restart & Run All" starts from the same random state.
# NumPy is seeded as well because the local data_load / APFL2 helpers may draw from
# it (not visible from this notebook) -- NOTE(review): confirm which RNGs they use.
SEED = 123
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x1a73033ff50>

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Experiment configuration: problem constants, step sizes and zero starting points.
lambda_global = 10

M = 10            # number of devices (passed as n_devices to data_load.data_prepare)
n_samples = 100   # passed as n_samples to data_load.data_prepare

obj = 'APFL2'
data_name = 'KMNIST'

L_prime = 1  # NOTE(review): presumably a smoothness constant -- confirm against APFL2

# One mixing weight shared by every device. L_w and L_beta are derived from this
# same value, so changing alpha here keeps them consistent with alpha_list.
alpha = 0.2
alpha_list = [alpha] * M

L_w = (lambda_global + alpha ** 2) * L_prime / M
L_beta = ((1 - alpha) ** 2) * L_prime / M

# prob1 is handed to the CD / CDVR trainers; prob2 is its complement.
prob1 = L_w / (L_w + L_beta)
prob2 = 1 - prob1

eta = 1e-1

rho = 0.01

n_communs = 1000
repo_step = 100
sync_step = 5


def _zero_params():
    """Return a fresh [weight, bias] pair of zero tensors allocated on `device`.

    Shapes are (10, 784) and (10,) -- presumably 10 classes by 28*28 pixels.
    A new pair is built on every call so no two parameter sets share storage.
    """
    # device= allocates directly on the target device instead of on the host
    # followed by a copy.
    return [torch.zeros(10, 784, device=device), torch.zeros(10, device=device)]


# Starting points: one global pair plus M per-device pairs for each list.
w0 = _zero_params()
beta0 = [_zero_params() for _ in range(M)]
w_list0 = [_zero_params() for _ in range(M)]
beta_list0 = [_zero_params() for _ in range(M)]

In [4]:
# Echo the derived constants and main hyper-parameters so they are recorded
# alongside the training logs.
for template, value in (
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("prob1 is {}", prob1),
    ("prob2 is {}", prob2),
    ("eta is {}", eta),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("sync_step is {}", sync_step),
):
    print(template.format(value))

L_w is: 1.004
L_beta is 0.06400000000000002
prob1 is 0.9400749063670412
prob2 is 0.05992509363295884
eta is 0.1
rho is 0.01
n_communs is 1000
sync_step is 5


In [5]:
# Prepare the training data for the chosen dataset, split over M devices.
(train_loader_list,
 devices_train_list) = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Local SGD baseline. Only the first of the three returned values is kept
# (the loss record, judging by its name); the other two are discarded.
loss_APFL2_KMNIST_LocalSGD, _, _ = APFL2.train_local_sgd(
    w_list0,             # zero-initialised per-device parameter pairs
    beta_list0,          # zero-initialised per-device parameter pairs
    alpha_list,
    sync_step,
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_global,
    repo_step,
    eta,
    obj,
    data_name,
)

num_commun: 1, loss: 25.328432083129883, time pass: 0s | Local_SGD APFL2 KMNIST
num_commun: 100, loss: 24.651032924652100, time pass: 20s | Local_SGD APFL2 KMNIST
num_commun: 200, loss: 24.079882621765137, time pass: 41s | Local_SGD APFL2 KMNIST
num_commun: 300, loss: 23.584928798675538, time pass: 60s | Local_SGD APFL2 KMNIST
num_commun: 400, loss: 23.176496315002442, time pass: 80s | Local_SGD APFL2 KMNIST
num_commun: 500, loss: 22.829382705688477, time pass: 102s | Local_SGD APFL2 KMNIST
num_commun: 600, loss: 22.529628467559814, time pass: 124s | Local_SGD APFL2 KMNIST
num_commun: 700, loss: 22.265361309051514, time pass: 146s | Local_SGD APFL2 KMNIST
num_commun: 800, loss: 22.030208873748780, time pass: 169s | Local_SGD APFL2 KMNIST
num_commun: 900, loss: 21.819151020050050, time pass: 193s | Local_SGD APFL2 KMNIST
num_commun: 1000, loss: 21.624181938171386, time pass: 218s | Local_SGD APFL2 KMNIST


In [8]:
# CD run. Only the first of the three returned values is kept (the loss
# record, judging by its name); the other two are discarded.
loss_APFL2_KMNIST_CD, _, _ = APFL2.train_CD(
    w0,                  # zero-initialised global parameter pair
    beta0,               # zero-initialised per-device parameter pairs
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_global,
    repo_step,
    eta,
    prob1,
    alpha_list,
    obj,
    data_name,
)

num_commun: 1, loss: 25.2818299294, time pass: 0s | CD APFL2 KMNIST
num_commun: 100, loss: 22.8495219231, time pass: 8s | CD APFL2 KMNIST
num_commun: 200, loss: 21.4707583427, time pass: 16s | CD APFL2 KMNIST
num_commun: 300, loss: 20.1008598804, time pass: 23s | CD APFL2 KMNIST
num_commun: 400, loss: 18.9846212387, time pass: 31s | CD APFL2 KMNIST
num_commun: 500, loss: 18.1402508259, time pass: 39s | CD APFL2 KMNIST
num_commun: 600, loss: 17.2794483662, time pass: 48s | CD APFL2 KMNIST
num_commun: 700, loss: 16.5850193977, time pass: 56s | CD APFL2 KMNIST
num_commun: 800, loss: 15.8915137768, time pass: 66s | CD APFL2 KMNIST
num_commun: 900, loss: 15.3278728962, time pass: 75s | CD APFL2 KMNIST
num_commun: 1000, loss: 14.7768237591, time pass: 84s | CD APFL2 KMNIST


In [9]:
# CDVR run. Takes the same inputs as the CD run plus rho. Only the first of
# the three returned values is kept; the other two are discarded.
# NOTE(review): w0 / beta0 are the same objects that were passed to train_CD.
# If the trainers update their inputs in place, this run would not start from
# zeros -- confirm that APFL2 copies its starting points.
loss_APFL2_KMNIST_CDVR, _, _ = APFL2.train_CDVR(
    w0,
    beta0,
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_global,
    repo_step,
    eta,
    prob1,
    rho,
    alpha_list,
    obj,
    data_name,
)

num_commun: 1, loss: 25.2866430283, time pass: 0s | CDVR APFL2 KMNIST
num_commun: 100, loss: 22.7015739441, time pass: 11s | CDVR APFL2 KMNIST
num_commun: 200, loss: 21.1422868252, time pass: 22s | CDVR APFL2 KMNIST
num_commun: 300, loss: 19.9692090034, time pass: 35s | CDVR APFL2 KMNIST
num_commun: 400, loss: 18.7814381123, time pass: 46s | CDVR APFL2 KMNIST
num_commun: 500, loss: 17.8463513851, time pass: 58s | CDVR APFL2 KMNIST
num_commun: 600, loss: 17.0468186378, time pass: 71s | CDVR APFL2 KMNIST
num_commun: 700, loss: 16.3083623886, time pass: 84s | CDVR APFL2 KMNIST
num_commun: 800, loss: 15.6179485798, time pass: 96s | CDVR APFL2 KMNIST
num_commun: 900, loss: 15.0125411510, time pass: 106s | CDVR APFL2 KMNIST
num_commun: 1000, loss: 14.4741307259, time pass: 118s | CDVR APFL2 KMNIST


In [10]:
# Persist each method's loss record with pickle.
# NOTE: the files carry a ".txt" extension but hold binary pickle data. The paths
# are kept unchanged so existing readers of ./result keep working.
results_to_save = {
    "./result/loss_APFL2_KMNIST_CD.txt": loss_APFL2_KMNIST_CD,
    "./result/loss_APFL2_KMNIST_CDVR.txt": loss_APFL2_KMNIST_CDVR,
    "./result/loss_APFL2_KMNIST_LocalSGD.txt": loss_APFL2_KMNIST_LocalSGD,
}

for path, loss_record in results_to_save.items():
    # Create the output directory on demand so a finished training run is not
    # lost to a FileNotFoundError when ./result does not exist yet.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:   # Pickling
        pickle.dump(loss_record, f)