In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from sklearn.datasets import fetch_openml


import warnings
warnings.filterwarnings('ignore', category=UserWarning)

from DSGT_functions import *
from DiTree import *

In [None]:
# Download MNIST ('mnist_784', version 1) from OpenML, then cast the
# attributes to float64 and the digit targets to uint8.
dataset = fetch_openml('mnist_784', version=1)
train_set_size = 50_000
test_set_size = 10_000

U0 = dataset['data']
v0 = dataset['target']

U = U0.astype(np.double)
v = v0.astype(np.uint8)

In [None]:
# Binary "is this digit a 5?" task: every label becomes +1 if it equals 5,
# and -1 otherwise.
v_bin_5_lst = [1 if digit == 5 else -1 for digit in v]

In [None]:
# Build one DataFrame: all attribute columns of U, followed by a single
# 'label' column holding the +/-1 targets.
df_U = pd.DataFrame(data=U)
df_v = pd.DataFrame({'label': np.asarray(v_bin_5_lst)})

# Align the labels to the attribute frame's index, then glue them on as the
# last column.
df_data_merged = pd.concat(
    [df_U, df_v.reindex(df_U.index)],
    axis=1,
)
df_data_merged.head()

In [None]:
"""a dictionary with key:value == name of run : (x_sol, f_vals, h_vals, consensus_err_vals)"""
# Registry filled by the solver cells below: run name -> averaged result tuple.
all_samplepaths = dict()
num_properties = 4  # number of entries in each result tuple

In [None]:
### ISR-DSGT ###
# Network size (number of nodes) and per-node stochastic-gradient settings.
m = 50
batch_size = 1
step_size_coeff = 1

data_impl = df_data_merged
# Split the data for the m nodes. The exact meaning of the three outputs is
# defined in DSGT_functions -- presumably train_sets holds one shard per node
# (TODO confirm).
train_set, train_sets, test_set = split_train_test(data_impl, train_set_size, test_set_size, m)

n_attributes = len(data_impl.columns) - 1  # every column except the trailing 'label'
N = train_set_size
mu_param = 10**-2

# One row of decision variables per node, all starting at zero.
# `.copy()` matters: a plain assignment would only alias the same array, so an
# in-place update inside a solver would corrupt the shared initial point.
x0_DSGT = np.zeros((m, n_attributes))
x0_DSGT_copy = x0_DSGT.copy()

# Sparse tree communication graph G and the matrix W handed to the solvers
# (presumably its mixing/weight matrix).
G_DSGT, W_DSGT = sparse_tree_graph(m, m // 4)

ISR_samplepaths = []
epochs = []

num_samplepaths = 1
max_T = 5

# Candidate values noted so far -- eta_param: {10, 5, 2}, K_param: {10, 5, 2}.
eta_param = 5
K_param = 10

t1_ISR = time.time()
for i in range(num_samplepaths):
    # Each run gets its own copy of the initial point, so every sample path
    # (and the later cells that reuse x0_DSGT) starts from identical state.
    x_sol0, f_vals0, h_vals0, consensus_err_vals0, epochs = seq_DSGT(
        train_sets, x0_DSGT_copy.copy(), W_DSGT, step_size_coeff, mu_param,
        batch_size, max_T, eta_param, K_param)
    ISR_samplepaths.append((x_sol0, f_vals0, h_vals0, consensus_err_vals0))
t2_ISR = time.time()
avg_time_ISR = (t2_ISR - t1_ISR) / num_samplepaths

# Average each of the num_properties result entries over all sample paths.
key = "ISR-DSGT"
all_samplepaths[key] = tuple(
    sum(path[k] for path in ISR_samplepaths) / num_samplepaths
    for k in range(num_properties)
)

print(f'{key} with {m} nodes:    f(x_{epochs[-1]}) = {all_samplepaths[key][1][-1]:.3f}         time(sec) = {avg_time_ISR:.1f}')

In [None]:
### IR-DSGT ###
# Reuses the data split, graph, initial point and step-size settings from the
# ISR-DSGT cell, and the `epochs` sequence returned by seq_DSGT.
eta_coeff = 1

DSGT_samplepaths = []
t1_DSGT = time.time()
# Repeat the run num_samplepaths times; the results are averaged below.
for i in range(num_samplepaths):
    # Fresh copy of the initial point per run, guarding against in-place updates.
    x_sol0, f_vals0, h_vals0, consensus_err_vals0 = iter_DSGT(
        train_sets, epochs, x0_DSGT.copy(), W_DSGT, step_size_coeff, eta_coeff,
        mu_param, batch_size)
    DSGT_samplepaths.append((x_sol0, f_vals0, h_vals0, consensus_err_vals0))
t2_DSGT = time.time()
avg_time_DSGT = (t2_DSGT - t1_DSGT) / num_samplepaths

# Average each of the num_properties result entries over all sample paths.
key = "IR-DSGT"
all_samplepaths[key] = tuple(
    sum(path[k] for path in DSGT_samplepaths) / num_samplepaths
    for k in range(num_properties)
)

print(f'{key} with {m} nodes:    f(x_{epochs[-1]}) = {all_samplepaths[key][1][-1]:.3f}         time(sec) = {avg_time_DSGT:.1f}')

In [None]:
### SGD ###
# Single-node baseline: iter_DSGT on a 1-vertex complete graph, started from
# the node-average of the DSGT initial point, with a batch of m samples.
# A dedicated name is used so the global batch_size (1, set in the ISR-DSGT
# cell) is not overwritten for later re-runs of the DSGT cells.
batch_size_SGD = m

x0_SGD = x0_DSGT.mean(0).reshape((1, n_attributes))
G_SGD, W_SGD = complete_graph(1)

SGD_samplepaths = []
t1_SGD = time.time()
for i in range(num_samplepaths):
    x_sol1, f_vals1, h_vals1, consensus_err_vals1 = iter_DSGT(
        train_sets, epochs, x0_SGD.copy(), W_SGD, step_size_coeff, eta_coeff,
        mu_param, batch_size_SGD)
    # Record this SGD run's own results (the *1 variables).
    SGD_samplepaths.append((x_sol1, f_vals1, h_vals1, consensus_err_vals1))
t2_SGD = time.time()
avg_time_SGD = (t2_SGD - t1_SGD) / num_samplepaths

# Average each of the num_properties result entries over all sample paths.
key = "batch SGD"
all_samplepaths[key] = tuple(
    sum(path[k] for path in SGD_samplepaths) / num_samplepaths
    for k in range(num_properties)
)

# Summary of every method, each reported with its own average wall-clock time
# (nan for any method without a recorded time).
avg_times = {"ISR-DSGT": avg_time_ISR, "IR-DSGT": avg_time_DSGT, "batch SGD": avg_time_SGD}
for key in all_samplepaths:
    print(f'{key}:    f(x_{epochs[-1]}) = {all_samplepaths[key][1][-1]:.3f}         time(sec) = {avg_times.get(key, float("nan")):.1f}')

In [None]:
# Inspect the "batch SGD" f-value trace next to the epoch sequence.
# The key is named explicitly instead of relying on the loop variable that
# leaks out of the previous cell.
inspect_key = "batch SGD"
print(epochs)

print(len(all_samplepaths[inspect_key][1]))
print(inspect_key)
print(all_samplepaths[inspect_key][1])

In [None]:
# Figure comparing the averaged sample paths of every method in all_samplepaths.
from itertools import cycle

# Per-method styling; each cycle has three entries and advances once per method
# per panel.
color_cycle = cycle(('r', 'g', 'b'))
marker_cycle = cycle('sP.')
msize_cycle = cycle((7, 7, 5))
linestyle_cycle = cycle(('dotted', 'dotted', 'dotted'))
lwidth = 4
alpha_cycle = cycle((0.6, 0.6, 0.6))

# Shift added to the f-values before taking the log in the first panel.
# It is computed from the averaged f curves that are actually plotted (every
# method). The minimum is clamped at 0, so the shift is exactly 1 whenever f is
# non-negative. Otherwise it is -1.1*min(f) + 1, which keeps the log argument
# >= 1 (the 1.1 factor adds head-room).
adj_const = -1.1 * min([0] + [np.min(paths[1]) for paths in all_samplepaths.values()]) + 1

fig = plt.figure(figsize=(30, 15))

# Panel 1 (top-left): log of the shifted f-values per method; the first entry
# of both the epoch axis and the series is skipped.
ax_f = plt.subplot(221)
for key in all_samplepaths:
    print(key)
    log_f_vals = np.log(all_samplepaths[key][1] + adj_const).tolist()
    ax_f.plot(epochs[1:], log_f_vals[1:],
              color=next(color_cycle), marker=next(marker_cycle),
              markersize=next(msize_cycle), linestyle=next(linestyle_cycle),
              label=key, linewidth=lwidth, alpha=next(alpha_cycle))
ax_f.legend(loc=3, fontsize=12)
ax_f.set_xlabel('Iteration', color='#1C2833', fontsize=14)
ax_f.set_ylabel(r"$g(x) = \ln\left(\sum_{i=1}^mf_i\left(\frac{\mathbf{1}^T_mx_k}{m}\right)+cst\right)$", color='#1C2833', fontsize=18)
ax_f.grid(True)

# Panel 2 (top-right): consensus error per method (entry 3 of each result
# tuple); the first entry of both axes is skipped.
ax_cons = plt.subplot(222)
for key in all_samplepaths:
    consensus_series = all_samplepaths[key][3].tolist()
    ax_cons.plot(epochs[1:], consensus_series[1:],
                 color=next(color_cycle), marker=next(marker_cycle),
                 markersize=next(msize_cycle), linestyle=next(linestyle_cycle),
                 label=key, linewidth=lwidth, alpha=next(alpha_cycle))
ax_cons.legend(loc=1, fontsize=12)
ax_cons.set_xlabel('Iteration', color='#1C2833', fontsize=14)
ax_cons.set_ylabel(r"consensus: $\left\|\mathbf{x}_k-\mathbf{1}_m\bar{x}_k\right\|$", color='#1C2833', fontsize=18)
ax_cons.grid(True)

# Panel 3 (bottom-left): h-values per method (entry 2 of each result tuple);
# the first entry of both axes is skipped.
plt.subplot(223)
for key in all_samplepaths:
    plt.plot(epochs[1:],
             all_samplepaths[key][2].tolist()[1:],
             color=next(color_cycle),
             marker=next(marker_cycle),
             markersize=next(msize_cycle),
             linestyle=next(linestyle_cycle),
             label=key,
             linewidth=lwidth,
             alpha=next(alpha_cycle))
plt.legend(loc=1, fontsize=12)
plt.xlabel('Iteration', color='#1C2833', fontsize=14)
# The axis is labelled h(x) to match the plotted h_vals and to stay distinct
# from the objective f shown in panel 1.
plt.ylabel(r"$h(x)= \frac{1}{2}\|x\|^2$", color='#1C2833', fontsize=18)
plt.grid(True)

# Panel 4 (bottom-right): the communication graph used by the DSGT runs.
plt.subplot(224)
plt.title(f"Tree graph with {m} nodes")
draw_graph(G_DSGT)
plt.show()