In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle
import random

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn

import data_load
import MT2

torch.set_printoptions(edgeitems=2, linewidth=75)

# Seed every RNG the notebook may touch so Restart & Run All is reproducible.
# numpy / random are seeded too because the data partitioning in data_load may
# rely on them (TODO confirm); torch.manual_seed also seeds the CUDA generators.
# torch.manual_seed is kept last so this cell still displays its Generator.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2230793ff50>

In [2]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Initialize all parameters
# ---- Objective / regularization weights ----
lambda_global = 10
lambda_penal = 1

M = 10            # number of simulated devices (data_load.data_prepare n_devices)
n_samples = 100   # forwarded to data_load.data_prepare as n_samples

obj = 'MT2'
data_name = 'MNIST'

# ---- Derived constants used to choose prob1 ----
L_w = (lambda_global + lambda_penal) / M
# NOTE(review): the recorded output of the next cell reads "L_beta is 0.2" and
# "prob1 is 0.846...", i.e. (1 + lambda_penal) / M, not the value 2 computed
# here. The saved loss curves were therefore produced with a different prob1
# than this cell now yields -- confirm the intended formula and re-run the
# notebook from a fresh kernel.
L_beta = 1 + lambda_penal

# prob1 is forwarded to train_CD / train_CDVR; presumably the probability of
# updating the shared w rather than the per-device beta -- confirm in MT2.
prob1 = L_w / (L_w + L_beta)
prob2 = 1 - prob1

eta = 1e-1   # step size given to all three trainers
rho = 1e-2   # only used by train_CDVR

n_communs = 1000  # communication rounds per training run
repo_step = 100   # logging interval (the training logs print every 100 rounds)
sync_step = 5     # only used by train_local_sgd

# Every model is a [weight, bias] pair of shapes (10, 784) and (10,)
# (presumably 10 MNIST classes x 28*28 pixels).
N_CLASSES = 10
N_FEATURES = 784


def _zero_params():
    """Return a fresh ``[weight, bias]`` pair of zero tensors on ``device``."""
    return [torch.zeros(N_CLASSES, N_FEATURES, device=device),
            torch.zeros(N_CLASSES, device=device)]


# NOTE(review): w0 / beta0 are handed to both train_CD and train_CDVR; if MT2
# updates its inputs in place, the second run would not start from zero.
w0 = _zero_params()
# One independent pair per device, so no tensor is shared between devices.
beta0 = [_zero_params() for _m in range(M)]
w_list0 = [_zero_params() for _m in range(M)]
beta_list0 = [_zero_params() for _m in range(M)]

In [4]:
# Echo the derived constants so the run configuration is captured in the output.
print("\n".join(template.format(value) for template, value in (
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("prob1 is {}", prob1),
    ("prob2 is {}", prob2),
    ("eta is {}", eta),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("sync_step is {}", sync_step),
)))

L_w is: 1.1
L_beta is 0.2
prob1 is 0.8461538461538461
prob2 is 0.15384615384615385
eta is 0.1
rho is 0.01
n_communs is 1000
sync_step is 5


In [5]:
# Split the dataset across M simulated devices (n_samples is forwarded as-is;
# presumably samples per device -- see data_load.data_prepare).
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Local SGD baseline. The trainer receives clones of the zero initial
# parameters, so re-running this cell (or MT2 updating its inputs in place)
# can never alter the initialization defined in the config cell.
loss_MT2_MNIST_LocalSGD, _, _ = MT2.train_local_sgd(
    [[p.clone() for p in params] for params in w_list0],
    [[p.clone() for p in params] for params in beta_list0],
    sync_step, n_communs, devices_train_list, train_loader_list,
    lambda_global, lambda_penal, repo_step, eta, obj, data_name)

num_commun: 1, loss: 25.3284320831, time pass: 0s | Local_SGD MT2 MNIST
num_commun: 100, loss: 24.6433942795, time pass: 17s | Local_SGD MT2 MNIST
num_commun: 200, loss: 24.2255694389, time pass: 34s | Local_SGD MT2 MNIST
num_commun: 300, loss: 23.8831072807, time pass: 49s | Local_SGD MT2 MNIST
num_commun: 400, loss: 23.5901333809, time pass: 65s | Local_SGD MT2 MNIST
num_commun: 500, loss: 23.3140561104, time pass: 82s | Local_SGD MT2 MNIST
num_commun: 600, loss: 23.0625981331, time pass: 98s | Local_SGD MT2 MNIST
num_commun: 700, loss: 22.8331655502, time pass: 113s | Local_SGD MT2 MNIST
num_commun: 800, loss: 22.6161151886, time pass: 130s | Local_SGD MT2 MNIST
num_commun: 900, loss: 22.4090280533, time pass: 145s | Local_SGD MT2 MNIST
num_commun: 1000, loss: 22.2120543480, time pass: 160s | Local_SGD MT2 MNIST


In [8]:
# train_CD run. w0 / beta0 are also passed to the CDVR cell, so this run gets
# its own copies; both methods then start from the same zero initialization
# regardless of cell order or in-place updates inside MT2.
loss_MT2_MNIST_CD, _, _ = MT2.train_CD(
    [p.clone() for p in w0],
    [[p.clone() for p in params] for params in beta0],
    n_communs, devices_train_list, train_loader_list,
    lambda_global, lambda_penal, repo_step, eta, prob1, obj, data_name)

num_commun: 1, loss: 25.2781715393, time pass: 0s | CD MT2 MNIST
num_commun: 100, loss: 23.1097467422, time pass: 4s | CD MT2 MNIST
num_commun: 200, loss: 21.8042819023, time pass: 9s | CD MT2 MNIST
num_commun: 300, loss: 20.7007495880, time pass: 14s | CD MT2 MNIST
num_commun: 400, loss: 19.7506006718, time pass: 19s | CD MT2 MNIST
num_commun: 500, loss: 18.7554688931, time pass: 24s | CD MT2 MNIST
num_commun: 600, loss: 17.9864258766, time pass: 29s | CD MT2 MNIST
num_commun: 700, loss: 17.1496109009, time pass: 34s | CD MT2 MNIST
num_commun: 800, loss: 16.4925632477, time pass: 39s | CD MT2 MNIST
num_commun: 900, loss: 15.7995675087, time pass: 44s | CD MT2 MNIST
num_commun: 1000, loss: 15.2127392769, time pass: 49s | CD MT2 MNIST


In [10]:
# train_CDVR run (the only trainer that takes rho). w0 / beta0 are also passed
# to the CD cell, so this run gets its own copies; both methods then start from
# the same zero initialization regardless of cell order or in-place updates.
loss_MT2_MNIST_CDVR, _, _ = MT2.train_CDVR(
    [p.clone() for p in w0],
    [[p.clone() for p in params] for params in beta0],
    n_communs, devices_train_list, train_loader_list,
    lambda_global, lambda_penal, repo_step, eta, prob1, rho, obj, data_name)

num_commun: 1, loss: 25.2772511482, time pass: 0s | CDVR MT2 MNIST
num_commun: 100, loss: 22.8898561478, time pass: 7s | CDVR MT2 MNIST
num_commun: 200, loss: 21.9230146408, time pass: 14s | CDVR MT2 MNIST
num_commun: 300, loss: 20.6447943687, time pass: 21s | CDVR MT2 MNIST
num_commun: 400, loss: 19.6485673904, time pass: 29s | CDVR MT2 MNIST
num_commun: 500, loss: 18.8457640648, time pass: 36s | CDVR MT2 MNIST
num_commun: 600, loss: 18.0835960865, time pass: 43s | CDVR MT2 MNIST
num_commun: 700, loss: 17.1905948162, time pass: 50s | CDVR MT2 MNIST
num_commun: 800, loss: 16.4883113861, time pass: 58s | CDVR MT2 MNIST
num_commun: 900, loss: 15.7983124733, time pass: 65s | CDVR MT2 MNIST
num_commun: 1000, loss: 15.1837946415, time pass: 73s | CDVR MT2 MNIST


In [12]:
# Persist the loss curves of the three methods for later plotting/comparison.
# NOTE: despite the ".txt" extension these files are binary pickles; read them
# back with pickle.load(open(path, "rb")), and only from trusted sources.
RESULT_DIR = "./result"
# open(..., "wb") raises FileNotFoundError when the folder does not exist,
# which would throw away the results of the several-minute training cells.
os.makedirs(RESULT_DIR, exist_ok=True)

for method, losses in (
    ("CD", loss_MT2_MNIST_CD),
    ("CDVR", loss_MT2_MNIST_CDVR),
    ("LocalSGD", loss_MT2_MNIST_LocalSGD),
):
    with open(f"{RESULT_DIR}/loss_MT2_MNIST_{method}.txt", "wb") as f:   # Pickling
        pickle.dump(losses, f)