In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import data_load
import MX2

# Keep printed tensors compact: at most 2 items shown at each edge of a
# dimension, and lines wrapped at 75 characters.
torch.set_printoptions(linewidth=75, edgeitems=2)

# Seed PyTorch's default random number generator so reruns start from the
# same random state. Kept as the cell's last expression: the returned
# Generator object is what the notebook displays as this cell's output.
torch.manual_seed(123)

<torch._C.Generator at 0x187983f0f30>

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Initialize all parameters

# Penalty weight; it determines both L_w and L_beta below and is also
# passed to the MX2 training routines.
lambda_penal = 1

M = 10           # number of devices (n_devices for data_load.data_prepare)
n_samples = 100  # n_samples for data_load.data_prepare

obj = 'MX2'          # objective identifier forwarded to the MX2 trainers
data_name = 'MNIST'  # dataset identifier forwarded to data loading/training

# Constants used to derive the sampling probability and the step size eta.
# Presumably smoothness constants of the shared and local parts — TODO confirm.
L_w = lambda_penal / M
L_beta = 1 + lambda_penal

# Forwarded unchanged as the last argument of the MX2 trainers.
rho = 1e-2

n_communs = 1000
n_epochs = 1000
repo_step = 100
sync_step = 5

# Shapes of the two tensors that make up one parameter set.
# Presumably weight (10 classes x 784 = 28*28 pixels) and bias — TODO confirm.
param_shapes = [(10, 784), (10,)]


def _zero_params():
    """Return a fresh, all-zero parameter set allocated directly on `device`.

    Each call creates new tensors, so sets never share storage. Passing
    `device=` avoids allocating on the host and then copying over.
    """
    return [torch.zeros(shape, device=device) for shape in param_shapes]


# Initial shared parameters.
w0 = _zero_params()

# One independent initial parameter set per device for each list below.
beta0 = [_zero_params() for _ in range(M)]
w_list0 = [_zero_params() for _ in range(M)]
beta_list0 = [_zero_params() for _ in range(M)]

In [4]:
# Echo the derived constants and run lengths so they are recorded with the
# notebook output. Each entry pairs a message template with its value.
hyperparameter_report = (
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("n_epochs is {}", n_epochs),
    ("sync_step is {}", sync_step),
)
for template, value in hyperparameter_report:
    print(template.format(value))

L_w is: 0.1
L_beta is 2
rho is 0.01
n_communs is 1000
n_epochs is 1000
sync_step is 5


In [5]:
# Prepare the data for M devices with n_samples each.
# Judging by the names, this returns one loader and one training set per
# device — verify against data_load.data_prepare.
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [6]:
# Length of every per-device training set, followed by max / min / mean.
n_sizes = [len(device_train) for device_train in devices_train_list]

size_report = (
    ("Max sample size: {}", max(n_sizes)),
    ("Min sample size: {}", min(n_sizes)),
    # The mean is truncated to an integer for display.
    ("Mean sample size: {}", int(np.mean(n_sizes))),
)
for template, value in size_report:
    print(template.format(value))

Max sample size: 100
Min sample size: 100
Mean sample size: 100


In [7]:
# Sweep over prob1 and run both CDVR training variants for every value,
# collecting one loss history per value in the two result lists.
#
# The first candidate, L_w / (L_w + L_beta), is the value at which
# L_w / prob1 equals L_beta / (1 - prob1); the rest are fixed grid points.
#
# NOTE(review): train_CDVR_iter receives n_communs and train_CDVR_commun
# receives n_epochs, yet the recorded logs print "epoch:" for the first run
# and "num_commun:" for the second. The two counts may be swapped; it is
# harmless while both are 1000 — confirm against the MX2 signatures.
# NOTE(review): the same w0 / beta0 objects are passed to every run. This
# assumes the trainers do not modify them in place — verify in MX2.
prob1_list = [L_w / (L_w + L_beta), 0.7, 0.5, 0.3, 0.1]
MX2_MNIST_CDVR_iter_result = []
MX2_MNIST_CDVR_commun_result = []

for prob1 in prob1_list:
    print("prob1 is {:.5f}".format(prob1))

    # Step size: 1 / (8 * the larger of the two probability-scaled constants).
    scaled_L_w = L_w / prob1
    scaled_L_beta = L_beta / (1 - prob1)
    eta = 1 / (8 * max(scaled_L_w, scaled_L_beta))

    print("eta is {:.5f}".format(eta))

    # Only the first of the three returned values (the loss history) is kept.
    loss_MX2_MNIST_CDVR_iter, _, _ = MX2.train_CDVR_iter(
        w0, beta0, n_communs,
        devices_train_list, train_loader_list,
        lambda_penal, repo_step,
        eta, prob1, obj, data_name, rho,
    )
    MX2_MNIST_CDVR_iter_result.append(loss_MX2_MNIST_CDVR_iter)

    loss_MX2_MNIST_CDVR_commun, _, _ = MX2.train_CDVR_commun(
        w0, beta0, n_epochs,
        devices_train_list, train_loader_list,
        lambda_penal, repo_step,
        eta, prob1, obj, data_name, rho,
    )
    MX2_MNIST_CDVR_commun_result.append(loss_MX2_MNIST_CDVR_commun)

prob1 is 0.04762
eta is 0.05952
epoch: 1, loss: 2.2986262321, time pass: 0s | CDVR MX2 MNIST
epoch: 100, loss: 2.0801148057, time pass: 8s | CDVR MX2 MNIST
epoch: 200, loss: 2.0020480990, time pass: 17s | CDVR MX2 MNIST
epoch: 300, loss: 1.9988469481, time pass: 28s | CDVR MX2 MNIST
epoch: 400, loss: 1.9850868583, time pass: 40s | CDVR MX2 MNIST
epoch: 500, loss: 1.9723028183, time pass: 52s | CDVR MX2 MNIST
epoch: 600, loss: 1.9636449635, time pass: 63s | CDVR MX2 MNIST
epoch: 700, loss: 1.9565222144, time pass: 75s | CDVR MX2 MNIST
epoch: 800, loss: 1.9484544396, time pass: 87s | CDVR MX2 MNIST
epoch: 900, loss: 1.9404346228, time pass: 99s | CDVR MX2 MNIST
epoch: 1000, loss: 1.9313554287, time pass: 111s | CDVR MX2 MNIST
num_commun: 1, loss: 2.2986262321, time pass: 0s | CDVR MX2 MNIST
num_commun: 100, loss: 1.8614854813, time pass: 143s | CDVR MX2 MNIST
num_commun: 200, loss: 1.7643390954, time pass: 289s | CDVR MX2 MNIST
num_commun: 300, loss: 1.6959381938, time pass: 419s | CDVR 

In [11]:
# Persist the collected loss histories with pickle.
# The files are binary pickles despite the ".txt" extension, so read them
# back with pickle.load on a file opened in "rb" mode.
result_files = {
    "./result/MX2_MNIST_CDVR_iter_result.txt": MX2_MNIST_CDVR_iter_result,
    "./result/MX2_MNIST_CDVR_commun_result.txt": MX2_MNIST_CDVR_commun_result,
}

for result_path, result in result_files.items():
    # Create the output directory if needed: open(..., "wb") raises
    # FileNotFoundError when it is missing, which would lose the whole run.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, "wb") as f:   #Pickling
        pickle.dump(result, f)