In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import data_load
import APFL2

# Keep printed tensors compact: two edge items per dimension, 75-char lines.
torch.set_printoptions(linewidth=75, edgeitems=2)
# Fix the PyTorch RNG seed for repeatable runs. This stays the last expression
# of the cell so the returned Generator is what the cell displays.
torch.manual_seed(123)

<torch._C.Generator at 0x2127e401f30>

In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Initialize all parameters
lambda_global = 10

M = 10            # number of devices (passed as n_devices to data_load.data_prepare)
n_samples = 100   # passed as n_samples to data_load.data_prepare

obj = 'APFL2'
data_name = 'MNIST'

# Single source of truth for the per-device mixing weight, so that L_w and
# L_beta below can never drift out of sync with alpha_list.
alpha = 0.2
alpha_list = [alpha] * M

# Constants used to derive prob1 and the step size eta in the training sweep.
# NOTE(review): presumably smoothness constants of the w- and beta-blocks of
# the objective -- confirm against the APFL2 module.
L_w = (lambda_global + alpha ** 2) * 1 / M
L_beta = ((1 - alpha) ** 2) * 1 / M

rho = 0.01

n_communs = 1000
n_epochs = 1000
repo_step = 100
sync_step = 5


def _zero_params():
    """Return a fresh [(10, 784) tensor, (10,) tensor] pair of zeros on `device`.

    NOTE(review): presumably the weight and bias of a 10-class linear model on
    flattened 28x28 inputs -- confirm against the APFL2 module.
    """
    # Allocate directly on the target device instead of CPU-then-copy.
    return [torch.zeros(10, 784, device=device), torch.zeros(10, device=device)]


# Shared starting point plus three independent per-device lists; each entry is
# built from its own call so that no tensors are aliased between entries.
w0 = _zero_params()
beta0 = [_zero_params() for _ in range(M)]
w_list0 = [_zero_params() for _ in range(M)]
beta_list0 = [_zero_params() for _ in range(M)]

In [4]:
# Echo the key hyper-parameters so they are recorded next to the results.
hyperparameter_summary = [
    f"L_w is: {L_w}",
    f"L_beta is {L_beta}",
    f"rho is {rho}",
    f"n_communs is {n_communs}",
    f"n_epochs is {n_epochs}",
    f"sync_step is {sync_step}",
]
print("\n".join(hyperparameter_summary))

L_w is: 1.004
L_beta is 0.06400000000000002
rho is 0.01
n_communs is 1000
n_epochs is 1000
sync_step is 5


In [5]:
# Prepare the training data for the M devices with n_samples each.
# NOTE(review): presumably returns one loader and one dataset per device --
# confirm against data_load.data_prepare.
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Sweep over prob1 and, for each value, run both CDVR training variants and
# keep the first value each returns (the loss history).
prob1_list = [L_w / (L_w + L_beta), 0.7, 0.5, 0.3, 0.1]
APFL2_MNIST_CDVR_iter_result = []
APFL2_MNIST_CDVR_commun_result = []

for prob1 in prob1_list:
    print(f"prob1 is {prob1:.5f}")

    # The step size is limited by the larger of the two rescaled constants.
    eta = 1 / (8 * max(L_w / prob1, L_beta / (1 - prob1)))

    print(f"eta is {eta:.5f}")

    # Both variants take exactly the same positional arguments.
    # NOTE(review): n_epochs is defined in the parameter cell but n_communs is
    # what gets passed to both calls -- confirm this is intended.
    shared_args = (
        w0, beta0, n_communs, devices_train_list,
        train_loader_list, lambda_global, repo_step,
        eta, prob1, rho, alpha_list, obj, data_name,
    )

    loss_APFL2_MNIST_CDVR_iter, _, _ = APFL2.train_CDVR_iter(*shared_args)
    APFL2_MNIST_CDVR_iter_result.append(loss_APFL2_MNIST_CDVR_iter)

    loss_APFL2_MNIST_CDVR_commun, _, _ = APFL2.train_CDVR_commun(*shared_args)
    APFL2_MNIST_CDVR_commun_result.append(loss_APFL2_MNIST_CDVR_commun)

prob1 is 0.94007
eta is 0.11704
epoch: 1, loss: 25.2690143585, time pass: 0s | CDVR APFL2 MNIST
epoch: 100, loss: 21.8484075546, time pass: 8s | CDVR APFL2 MNIST
epoch: 200, loss: 19.7768568039, time pass: 16s | CDVR APFL2 MNIST
epoch: 300, loss: 18.1601456165, time pass: 26s | CDVR APFL2 MNIST
epoch: 400, loss: 16.4941403866, time pass: 36s | CDVR APFL2 MNIST
epoch: 500, loss: 15.2231342316, time pass: 47s | CDVR APFL2 MNIST
epoch: 600, loss: 14.1650820732, time pass: 59s | CDVR APFL2 MNIST
epoch: 700, loss: 13.2390500069, time pass: 71s | CDVR APFL2 MNIST
epoch: 800, loss: 12.3740795135, time pass: 84s | CDVR APFL2 MNIST
epoch: 900, loss: 11.6393677711, time pass: 96s | CDVR APFL2 MNIST
epoch: 1000, loss: 11.0157681227, time pass: 108s | CDVR APFL2 MNIST
num_commun: 1, loss: 25.2690143585, time pass: 0s | CDVR APFL2 MNIST
num_commun: 100, loss: 21.7755082130, time pass: 12s | CDVR APFL2 MNIST
num_commun: 200, loss: 19.6542203903, time pass: 25s | CDVR APFL2 MNIST
num_commun: 300, los

In [9]:
# Persist both loss histories with pickle.
# Create the output directory first: open(..., "wb") does not create missing
# parent directories, and a FileNotFoundError here would come only after the
# whole training sweep has finished.
os.makedirs("./result", exist_ok=True)

# NOTE: the files are binary pickles despite the .txt extension; the names are
# kept unchanged so that existing readers of these paths keep working.
with open("./result/APFL2_MNIST_CDVR_iter_result.txt", "wb") as f:   #Pickling
    pickle.dump(APFL2_MNIST_CDVR_iter_result, f)

with open("./result/APFL2_MNIST_CDVR_commun_result.txt", "wb") as f:   #Pickling
    pickle.dump(APFL2_MNIST_CDVR_commun_result, f)