In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import data_load
import APFL2

# Keep printed tensors compact: 2 edge items per dimension, 75-column lines.
torch.set_printoptions(linewidth=75, edgeitems=2)
# Fix the torch RNG seed so runs are repeatable. As the last expression of the
# cell, this call also echoes the returned Generator object.
torch.manual_seed(123)

<torch._C.Generator at 0x18451851f30>

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Initialize all parameters
lambda_global = 10

M = 10           # number of devices (passed as n_devices to data_load.data_prepare)
n_samples = 100  # passed as n_samples to data_load.data_prepare

obj = 'APFL2'
data_name = 'KMNIST'

# Single source of truth for the per-device mixing weight. L_w and L_beta below
# are derived from it, so changing `alpha` keeps alpha_list, L_w and L_beta
# (and the step size computed from them later) consistent with each other.
alpha = 0.2
alpha_list = [alpha] * M

# NOTE(review): presumably smoothness constants of the objective with respect
# to w and beta -- confirm against the APFL2 module.
L_w = (lambda_global + alpha ** 2) * 1 / M
L_beta = ((1 - alpha) ** 2) * 1 / M

rho = 0.01

n_communs = 1000
n_epochs = 1000
repo_step = 100
sync_step = 5


def _zero_params():
    """Return a fresh [weight, bias] pair of zero tensors on `device`.

    The weight is 10 x 784 and the bias has length 10 (presumably 10 classes
    and 784 flattened pixels -- confirm against data_load). A new pair is
    allocated on every call so no two parameter sets share storage.
    """
    return [torch.zeros(10, 784, device=device), torch.zeros(10, device=device)]


# Shared starting point w0 plus one independent zero-initialised pair per device.
w0 = _zero_params()

beta0 = [_zero_params() for _ in range(M)]

w_list0 = [_zero_params() for _ in range(M)]

beta_list0 = [_zero_params() for _ in range(M)]

In [4]:
# Echo the derived constants and the run configuration, one per line, so the
# values used for this run are recorded in the cell output.
for template, value in (
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("n_epochs is {}", n_epochs),
    ("sync_step is {}", sync_step),
):
    print(template.format(value))

L_w is: 1.004
L_beta is 0.06400000000000002
rho is 0.01
n_communs is 1000
n_epochs is 1000
sync_step is 5


In [5]:
# Build the training data for `data_name`, requesting M devices and n_samples
# samples. The call returns two values, unpacked here as the loader list and
# the per-device training list (exact structure is defined in data_load).
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Sweep over candidate values of prob1 and record the loss history of both
# CDVR trainers for each. The first candidate, L_w / (L_w + L_beta), is the
# value at which L_w/prob1 and L_beta/(1-prob1) are analytically equal.
APFL2_KMNIST_CDVR_iter_result = []
APFL2_KMNIST_CDVR_commun_result = []
prob1_list = [L_w / (L_w + L_beta), 0.7, 0.5, 0.3, 0.1]

for prob1 in prob1_list:
    print("prob1 is {:.5f}".format(prob1))

    # Step size: 1 / (8 * the larger of the two probability-scaled constants).
    scaled_bound = max(L_w / prob1, L_beta / (1 - prob1))
    eta = 1 / (8 * scaled_bound)

    print("eta is {:.5f}".format(eta))

    # Both trainers are called with exactly the same positional arguments.
    # NOTE(review): n_communs is also passed to train_CDVR_iter, whose log
    # lines count "epoch" -- confirm n_epochs was not intended (both are 1000).
    train_args = (w0, beta0, n_communs, devices_train_list,
                  train_loader_list, lambda_global, repo_step,
                  eta, prob1, rho, alpha_list, obj, data_name)

    # Only the first of the three returned values (the loss history) is kept.
    loss_APFL2_KMNIST_CDVR_iter, _, _ = APFL2.train_CDVR_iter(*train_args)
    APFL2_KMNIST_CDVR_iter_result.append(loss_APFL2_KMNIST_CDVR_iter)

    loss_APFL2_KMNIST_CDVR_commun, _, _ = APFL2.train_CDVR_commun(*train_args)
    APFL2_KMNIST_CDVR_commun_result.append(loss_APFL2_KMNIST_CDVR_commun)
    print(len(loss_APFL2_KMNIST_CDVR_commun))

prob1 is 0.94007
eta is 0.11704
epoch: 1, loss: 25.2795491219, time pass: 0s | CDVR APFL2 KMNIST
epoch: 100, loss: 22.4642164230, time pass: 8s | CDVR APFL2 KMNIST
epoch: 200, loss: 20.8371619225, time pass: 17s | CDVR APFL2 KMNIST
epoch: 300, loss: 19.6392883778, time pass: 27s | CDVR APFL2 KMNIST
epoch: 400, loss: 18.3831287384, time pass: 38s | CDVR APFL2 KMNIST
epoch: 500, loss: 17.3925869465, time pass: 50s | CDVR APFL2 KMNIST
epoch: 600, loss: 16.5547383785, time pass: 62s | CDVR APFL2 KMNIST
epoch: 700, loss: 15.8117887497, time pass: 74s | CDVR APFL2 KMNIST
epoch: 800, loss: 15.1104201317, time pass: 86s | CDVR APFL2 KMNIST
epoch: 900, loss: 14.4991574287, time pass: 99s | CDVR APFL2 KMNIST
epoch: 1000, loss: 13.9647968292, time pass: 111s | CDVR APFL2 KMNIST
num_commun: 1, loss: 25.2795491219, time pass: 0s | CDVR APFL2 KMNIST
num_commun: 100, loss: 22.4090167046, time pass: 12s | CDVR APFL2 KMNIST
num_commun: 200, loss: 20.7464145660, time pass: 25s | CDVR APFL2 KMNIST
num_co

In [8]:
# Persist the loss histories of both sweeps with pickle (binary files, despite
# the .txt extension, which is kept so existing readers still find them).
# Create the output directory first: open(..., "wb") does not create missing
# parent directories, and failing here would discard the whole training run.
os.makedirs("./result", exist_ok=True)

with open("./result/APFL2_KMNIST_CDVR_iter_result.txt", "wb") as f:   # Pickling
    pickle.dump(APFL2_KMNIST_CDVR_iter_result, f)

with open("./result/APFL2_KMNIST_CDVR_commun_result.txt", "wb") as f:   # Pickling
    pickle.dump(APFL2_KMNIST_CDVR_commun_result, f)