In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import data_load
import MX2

# Keep printed tensors compact (2 edge items per dimension, 75-character lines)
# and fix the global torch RNG seed so repeated runs draw the same random numbers.
torch.set_printoptions(linewidth=75, edgeitems=2)
torch.manual_seed(123)

<torch._C.Generator at 0x2424f4f1f50>

In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [4]:
# Initialize all parameters

# --- Problem definition -----------------------------------------------------
lambda_penal = 1      # penalty weight (lambda) passed to every trainer
M = 10                # number of devices (also passed as n_devices to data_load)
n_samples = 100       # passed as n_samples to data_load.data_prepare

obj = 'MX2'           # objective label forwarded to the trainers
data_name = 'FMNIST'  # dataset label forwarded to data_load and the trainers

# --- Constants used to derive the sampling probabilities --------------------
# Presumably smoothness constants of the MX2 objective (L_prime being the
# smoothness of the local losses) -- confirm against the MX2 module / paper.
L_prime = 1

L_w = lambda_penal / M
# Both constants carry the 1/M factor and L_beta is built from L_prime.
# This yields L_beta = 0.2 and prob1 = 0.333..., i.e. exactly the values shown
# in the notebook's recorded output. The previous expression
# `1 + lambda_penal` gave L_beta = 2 (prob1 ~ 0.048) and left L_prime unused.
L_beta = (L_prime + lambda_penal) / M

# prob1 is forwarded to MX2.train_CD / MX2.train_CDVR; prob2 is its complement.
prob1 = L_w / (L_w + L_beta)
prob2 = 1 - prob1

# --- Optimisation settings --------------------------------------------------
eta = 1e-1            # step size forwarded to every trainer

rho = 1e-2            # extra argument used only by MX2.train_CDVR

n_communs = 1000      # number of communication rounds
repo_step = 100       # reporting interval (the logs print every 100 rounds)
sync_step = 5         # used only by MX2.train_local_sgd

def _zero_params():
    """Return a fresh ``[(10, 784) tensor, (10,) tensor]`` pair of zeros on ``device``.

    The tensors are allocated directly on ``device`` instead of being created
    on the CPU and copied over. Every call returns new tensors, so no two
    parameter sets share storage.
    (Presumably the weight and bias of a 784 -> 10 linear classifier --
    confirm against the MX2 module.)
    """
    return [torch.zeros(10, 784, device=device), torch.zeros(10, device=device)]


# Starting point for the single shared parameter set used by train_CD / train_CDVR.
w0 = _zero_params()

# One independent zero-initialised parameter set per device (M in total).
beta0 = [_zero_params() for _ in range(M)]

# Per-device starting points handed to train_local_sgd.
w_list0 = [_zero_params() for _ in range(M)]
beta_list0 = [_zero_params() for _ in range(M)]

In [5]:
# Echo the derived constants and the main run settings, one per line.
for template, value in [
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("prob1 is {}", prob1),
    ("prob2 is {}", prob2),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("sync_step is {}", sync_step),
]:
    print(template.format(value))

L_w is: 0.1
L_beta is 0.2
prob1 is 0.3333333333333333
prob2 is 0.6666666666666667
rho is 0.01
n_communs is 1000
sync_step is 5


In [5]:
# Build the per-device training data for `data_name`, split across M devices
# with `n_samples` passed through to the loader.
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Local SGD run. The trainer returns three values; only the first (stored as
# the loss history and pickled later) is kept.
loss_MX2_FMNIST_LocalSGD, _, _ = MX2.train_local_sgd(
    w_list0,
    beta_list0,
    sync_step,
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_penal,
    repo_step,
    eta,
    obj,
    data_name,
)

num_commun: 1, loss: 2.302584767341614, time pass: 0s | Local_SGD MX2 FMNIST
num_commun: 100, loss: 2.022634959220886, time pass: 19s | Local_SGD MX2 FMNIST
num_commun: 200, loss: 1.992257821559906, time pass: 38s | Local_SGD MX2 FMNIST
num_commun: 300, loss: 1.986917591094971, time pass: 57s | Local_SGD MX2 FMNIST
num_commun: 400, loss: 1.983015775680542, time pass: 76s | Local_SGD MX2 FMNIST
num_commun: 500, loss: 1.979705470800400, time pass: 95s | Local_SGD MX2 FMNIST
num_commun: 600, loss: 1.975946146249771, time pass: 115s | Local_SGD MX2 FMNIST
num_commun: 700, loss: 1.972659128904343, time pass: 136s | Local_SGD MX2 FMNIST
num_commun: 800, loss: 1.969220441579819, time pass: 159s | Local_SGD MX2 FMNIST
num_commun: 900, loss: 1.965963608026505, time pass: 181s | Local_SGD MX2 FMNIST
num_commun: 1000, loss: 1.962999510765076, time pass: 204s | Local_SGD MX2 FMNIST


In [8]:
# CD run. The trainer returns three values; only the first is kept.
# NOTE(review): w0 / beta0 are also handed to MX2.train_CDVR afterwards. If
# train_CD updates them in place, that later run would not start from zeros --
# confirm in the MX2 module.
loss_MX2_FMNIST_CD, _, _ = MX2.train_CD(
    w0,
    beta0,
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_penal,
    repo_step,
    eta,
    prob1,
    obj,
    data_name,
)

num_commun: 1, loss: 2.3025847673, time pass: 0s | CD MX2 FMNIST
num_commun: 100, loss: 1.9640193820, time pass: 10s | CD MX2 FMNIST
num_commun: 200, loss: 1.9120150566, time pass: 22s | CD MX2 FMNIST
num_commun: 300, loss: 1.8687236905, time pass: 36s | CD MX2 FMNIST
num_commun: 400, loss: 1.8334607124, time pass: 51s | CD MX2 FMNIST
num_commun: 500, loss: 1.8018663585, time pass: 66s | CD MX2 FMNIST
num_commun: 600, loss: 1.7719134033, time pass: 85s | CD MX2 FMNIST
num_commun: 700, loss: 1.7447489738, time pass: 103s | CD MX2 FMNIST
num_commun: 800, loss: 1.7203852594, time pass: 118s | CD MX2 FMNIST
num_commun: 900, loss: 1.6997707844, time pass: 131s | CD MX2 FMNIST
num_commun: 1000, loss: 1.6771988332, time pass: 145s | CD MX2 FMNIST


In [9]:
# CDVR run: same arguments as the CD run plus `rho` as the final argument.
# The trainer returns three values; only the first is kept.
# NOTE(review): w0 / beta0 were already passed to MX2.train_CD. Confirm that
# call leaves them unmodified so this run starts from the same zeros.
loss_MX2_FMNIST_CDVR, _, _ = MX2.train_CDVR(
    w0,
    beta0,
    n_communs,
    devices_train_list,
    train_loader_list,
    lambda_penal,
    repo_step,
    eta,
    prob1,
    obj,
    data_name,
    rho,
)

num_commun: 1, loss: 2.2954288363, time pass: 0s | CDVR MX2 FMNIST
num_commun: 100, loss: 2.0102390528, time pass: 24s | CDVR MX2 FMNIST
num_commun: 200, loss: 1.9165530562, time pass: 47s | CDVR MX2 FMNIST
num_commun: 300, loss: 1.8791693568, time pass: 69s | CDVR MX2 FMNIST
num_commun: 400, loss: 1.8324231684, time pass: 93s | CDVR MX2 FMNIST
num_commun: 500, loss: 1.7970671237, time pass: 118s | CDVR MX2 FMNIST
num_commun: 600, loss: 1.7664565623, time pass: 142s | CDVR MX2 FMNIST
num_commun: 700, loss: 1.7416998148, time pass: 167s | CDVR MX2 FMNIST
num_commun: 800, loss: 1.7169719100, time pass: 192s | CDVR MX2 FMNIST
num_commun: 900, loss: 1.6928803205, time pass: 214s | CDVR MX2 FMNIST
num_commun: 1000, loss: 1.6739354193, time pass: 229s | CDVR MX2 FMNIST


In [10]:
# Persist the three loss histories with pickle.
# NOTE: despite the ".txt" extension these files are binary pickles; the paths
# are kept unchanged so existing readers of ./result keep working.
results_to_save = {
    "./result/loss_MX2_FMNIST_CD.txt": loss_MX2_FMNIST_CD,
    "./result/loss_MX2_FMNIST_CDVR.txt": loss_MX2_FMNIST_CDVR,
    "./result/loss_MX2_FMNIST_LocalSGD.txt": loss_MX2_FMNIST_LocalSGD,
}

# Create the output directory on a fresh checkout, so that several minutes of
# training are not lost to a FileNotFoundError at save time.
os.makedirs("./result", exist_ok=True)

for path, losses in results_to_save.items():
    with open(path, "wb") as f:   # Pickling
        pickle.dump(losses, f)