# This is a notebook to plot the results of the paper

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pickle

In [None]:

# Load the saved result curves. Change 'temp.data' to the file the results were
# saved in, and adjust the names below if that file holds a different set of
# results. Seven objects are read back-to-back from the same file, in this order:
# test/train loss, test/train error, vACCs5/tACCs5, average gradient disparity.
# NOTE: unpickling can run arbitrary code, so only open result files you trust.
with open('temp.data', 'rb') as results_file:
    (vLosses, tLosses, vERRs, tERRs,
     vACCs5, tACCs5, avg_GD) = [pickle.load(results_file) for _ in range(7)]
    


In [None]:
def _clipped_gap(test_values, train_values):
    """Return the element-wise gap ``test - train``, zeroed where test <= train.

    The difference is multiplied by the boolean ``test > train``, so only
    strictly positive gaps survive; every other entry becomes zero.
    """
    return [(test - train) * (test > train)
            for test, train in zip(test_values, train_values)]


# Generalization gap of the loss and of the error, one value per recorded epoch.
gen_loss = _clipped_gap(vLosses, tLosses)
gen_err = _clipped_gap(vERRs, tERRs)

# Number of recorded values; used as the length of the x-axis in the plots below.
num_epochs = np.array(vLosses).size

# Loss curves

In [None]:
# Test, train and generalization loss against the epoch index (log-scaled x-axis).
fig, axs = plt.subplots(1, 1, figsize=[16,9])
epochs = np.arange(0, num_epochs, 1)
x = epochs

# (values, legend label) for each curve; the drawing order fixes which default
# colour each curve receives, so it is kept as test, train, generalization.
loss_curves = [
    (vLosses, 'test loss'),
    (tLosses, 'train loss'),
    (gen_loss, 'generalization loss'),
]

loss_handles = []
for curve_values, curve_label in loss_curves:
    # semilogx returns a list of lines; unpack the single line for the legend.
    curve_line, = axs.semilogx(x, curve_values, label=curve_label, linewidth=5.0)
    loss_handles.append(curve_line)

axs.legend(handles=loss_handles, prop={'size': 40})

axs.set_xlabel('epochs', fontsize=30)
axs.set_ylabel('Cross entropy loss', fontsize=30)

# Same tick label size on both axes.
for axis_name in ("x", "y"):
    axs.tick_params(axis=axis_name, labelsize=27)
# fig.savefig('figures/loss.png')

# Error curves

In [None]:
# Test, train and generalization error against the epoch index (log-scaled x-axis).
fig, axs = plt.subplots(1, 1, figsize=[16,9])
epochs = np.arange(0, num_epochs, 1)
x = epochs

# Legend label -> plotted series. Insertion order is the drawing order, which
# determines the default colour assigned to each curve.
error_series = {
    'test error': vERRs,
    'train error': tERRs,
    'generalization error': gen_err,
}

error_handles = []
for series_label, series_values in error_series.items():
    # Each call draws exactly one line; keep its handle for the legend.
    series_line, = axs.semilogx(x, series_values, label=series_label, linewidth=5.0)
    error_handles.append(series_line)

axs.legend(handles=error_handles, prop={'size': 40})

axs.set_xlabel('epochs', fontsize=30)
axs.set_ylabel('Error percentage', fontsize=30)

# Same tick label size on both axes.
for axis_name in ("x", "y"):
    axs.tick_params(axis=axis_name, labelsize=27)
#fig.savefig('figures/error.png')


# Gradient disparity vs Loss

In [None]:
# Generalization and test loss (left y-axis) together with the average gradient
# disparity (right y-axis), all against the epoch index on a log-scaled x-axis.
fig, axs = plt.subplots(1, 1, figsize=[16,9])
epochs = np.arange(0, num_epochs, 1)
x = epochs

test_loss_color = 'tab:blue'
disparity_color = 'tab:red'

# Left axis: the two loss curves.
gen_loss_line, = axs.semilogx(
    x, gen_loss, color="tab:green", label='generalization loss', linewidth=5.0)
test_loss_line, = axs.semilogx(
    x, vLosses, color=test_loss_color, label='test loss', linewidth=5.0)
axs.set_xlabel('epochs', fontsize=30)
axs.set_ylabel('Cross entropy loss', color="black", fontsize=30)
axs.tick_params(axis='y', labelcolor="black", labelsize=27)
axs.tick_params(axis='x', labelsize=27)

# Right axis: shares the x-axis with the loss curves but has its own y scale,
# coloured to match the disparity curve.
ax2 = axs.twinx()
disparity_line, = ax2.semilogx(
    x, avg_GD, color=disparity_color, linewidth=5.0, label='gradient disparity')
ax2.set_ylabel('Average gradient disparity', color=disparity_color, fontsize=30)
ax2.tick_params(axis='y', labelcolor=disparity_color, labelsize=27)

# A single legend on the left axes that also lists the right-axis curve.
axs.legend(handles=[gen_loss_line, test_loss_line, disparity_line],
           prop={'size':40}, loc='upper left')
#fig.savefig('figures/loss_vs_GD.png')

# Gradient disparity vs Error

In [None]:
# Generalization error (left y-axis) together with the average gradient
# disparity (right y-axis), both against the epoch index on a log-scaled x-axis.
fig, axs = plt.subplots(1, 1, figsize=[16,9])
epochs = np.arange(0, num_epochs, 1)

gen_err_color = 'tab:green'
disparity_color = 'tab:red'

# Left axis: generalization error, with label and ticks coloured like the curve.
axs.set_ylabel('Generalization error', color=gen_err_color, fontsize=30)
# The x data is the epoch index, so the axis is labelled 'epochs' like the
# other three figures (it previously read 'Iterations').
axs.set_xlabel('epochs', fontsize=30)
t0, = axs.semilogx(epochs, gen_err, color=gen_err_color, linewidth=5.0,
                   label='generalization error')

axs.tick_params(axis='y', labelcolor=gen_err_color, labelsize=27)
axs.tick_params(axis='x', labelsize=27)

# Right axis: shares the x-axis with the error curve but has its own y scale.
ax2 = axs.twinx()
ax2.set_ylabel('Average gradient disparity', color=disparity_color, fontsize=30)

# Plot against this cell's own `epochs`. The earlier version used `x`, which is
# only defined by the other plotting cells, so running this cell on its own
# raised a NameError.
t1, = ax2.semilogx(epochs, avg_GD, color=disparity_color, linewidth=5.0,
                   label='gradient disparity')

ax2.tick_params(axis='y', labelcolor=disparity_color, labelsize=27)

# A single legend on the left axes that also lists the right-axis curve.
axs.legend(handles=[t0,t1], prop={'size':40})
#fig.savefig('figures/error_vs_GD.png')