In [1]:
# Load packages and set up default settings
%matplotlib inline
import datetime
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

import data_load
import MT2

# Keep printed tensors compact, then pin the global torch RNG seed so that
# repeated runs start from the same random state.
torch.set_printoptions(linewidth=75, edgeitems=2)
torch.manual_seed(123)

<torch._C.Generator at 0x1ff24eb2f30>

In [2]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

Training on device cuda.


In [3]:
# Experiment configuration and zero-initialised starting points.

# Regularisation weights forwarded to the MT2 training routines.
lambda_global = 10
lambda_penal = 1

# Passed to data_load.data_prepare as n_devices / n_samples.
M = 10
n_samples = 100

# Objective and dataset identifiers forwarded to the training routines.
obj = 'MT2'
data_name = 'KMNIST'

# Constants from which prob1 and the step size eta are derived in the sweep.
L_w = (lambda_global + lambda_penal) / M
L_beta = (1 + lambda_penal)

rho = 1e-2

# Run length, reporting interval, and synchronisation interval.
n_communs = 1000
repo_step = 100
sync_step = 5


def _zero_params():
    """Return a fresh [zeros(10, 784), zeros(10)] pair placed on `device`."""
    return [torch.zeros(10, 784).to(device), torch.zeros(10).to(device)]


# One shared parameter pair, plus M independent pairs for each list.
w0 = _zero_params()
beta0 = [_zero_params() for _ in range(M)]
w_list0 = [_zero_params() for _ in range(M)]
beta_list0 = [_zero_params() for _ in range(M)]

In [4]:
# Echo the derived constants and run settings so they are recorded with the
# notebook output. Each template is paired with the value it reports.
for template, value in (
    ("L_w is: {}", L_w),
    ("L_beta is {}", L_beta),
    ("rho is {}", rho),
    ("n_communs is {}", n_communs),
    ("sync_step is {}", sync_step),
):
    print(template.format(value))

L_w is: 1.1
L_beta is 2
prob1 is 0.3548387096774194
prob2 is 0.6451612903225806
eta is 0.1
rho is 0.01
n_communs is 1000
sync_step is 5


In [5]:
# Build the per-device training loaders and datasets: M devices, each
# requested with n_samples samples from the dataset named by data_name.
train_loader_list, devices_train_list = data_load.data_prepare(
    data_name,
    n_devices=M,
    n_samples=n_samples,
)

In [7]:
# Sweep over prob1 and, for each value, train CDVR twice: once counted in
# iterations (train_CDVR_iter) and once counted in communication rounds
# (train_CDVR_commun). The loss history of every run is collected.

# n_epochs was not defined in any remaining cell (the execution count jumps
# from 5 to 7), so a fresh-kernel Run All failed with NameError. The recorded
# log of the iteration-based run ends at epoch 1000, hence this value.
n_epochs = 1000

# The first entry is the ratio L_w / (L_w + L_beta); the others are a fixed grid.
prob1_list = [L_w / (L_w + L_beta), 0.7, 0.5, 0.3, 0.1]
MT2_KMNIST_CDVR_iter_result = []
MT2_KMNIST_CDVR_commun_result = []

# NOTE(review): w0 and beta0 are reused for every run; this assumes the MT2
# training routines do not modify them in place. Confirm against MT2.
for prob1 in prob1_list:
    print("prob1 is {:.5f}".format(prob1))

    # Step size derived from the larger of the two probability-scaled constants.
    eta = 1 / (8 * max(L_w/prob1, L_beta/(1-prob1)))

    print("eta is {:.5f}".format(eta))

    loss_MT2_KMNIST_CDVR_iter, _, _ = MT2.train_CDVR_iter(w0, beta0, n_epochs, devices_train_list,
                                                         train_loader_list, lambda_global, lambda_penal,
                                                         repo_step, eta, prob1, rho, obj, data_name)
    MT2_KMNIST_CDVR_iter_result.append(loss_MT2_KMNIST_CDVR_iter)
    loss_MT2_KMNIST_CDVR_commun, _, _ = MT2.train_CDVR_commun(w0, beta0, n_communs, devices_train_list,
                                                             train_loader_list, lambda_global, lambda_penal,
                                                             repo_step, eta, prob1, rho, obj, data_name)
    MT2_KMNIST_CDVR_commun_result.append(loss_MT2_KMNIST_CDVR_commun)

prob1 is 0.35484
eta is 0.04032
epoch: 1, loss: 25.3113583565, time pass: 0s | CDVR MT2 KMNIST
epoch: 100, loss: 24.2383908272, time pass: 14s | CDVR MT2 KMNIST
epoch: 200, loss: 23.6002992630, time pass: 28s | CDVR MT2 KMNIST
epoch: 300, loss: 23.1946971893, time pass: 42s | CDVR MT2 KMNIST
epoch: 400, loss: 22.7914344788, time pass: 60s | CDVR MT2 KMNIST
epoch: 500, loss: 22.4591097832, time pass: 79s | CDVR MT2 KMNIST
epoch: 600, loss: 22.1677553177, time pass: 97s | CDVR MT2 KMNIST
epoch: 700, loss: 21.8921775818, time pass: 117s | CDVR MT2 KMNIST
epoch: 800, loss: 21.6207988739, time pass: 136s | CDVR MT2 KMNIST
epoch: 900, loss: 21.3668675423, time pass: 156s | CDVR MT2 KMNIST
epoch: 1000, loss: 21.1212728500, time pass: 175s | CDVR MT2 KMNIST
num_commun: 1, loss: 25.3113583565, time pass: 0s | CDVR MT2 KMNIST
num_commun: 100, loss: 23.3517419815, time pass: 36s | CDVR MT2 KMNIST
num_commun: 200, loss: 22.3792099953, time pass: 72s | CDVR MT2 KMNIST
num_commun: 300, loss: 21.6364

In [8]:
# Persist the loss histories collected in the prob1 sweep.
# NOTE: the files are binary pickles despite the .txt suffix; read them back
# with pickle.load on a file opened in "rb" mode.

# Create the output directory first so that a missing ./result does not
# discard the results of the long training runs with FileNotFoundError.
os.makedirs("./result", exist_ok=True)

with open("./result/MT2_KMNIST_CDVR_iter_result.txt", "wb") as f:   #Pickling
    pickle.dump(MT2_KMNIST_CDVR_iter_result, f)

with open("./result/MT2_KMNIST_CDVR_commun_result.txt", "wb") as f:   #Pickling
    pickle.dump(MT2_KMNIST_CDVR_commun_result, f)