In [None]:
from pathlib import Path
from zipfile import ZipFile
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Publication style for the Gravity paper figures: reset matplotlib to its
# defaults, then apply compact text sizes and high-resolution output.
plt.style.use('default')
gravity_paper_params = {
    # Text sizes (relative to the base font size)
    'axes.titlesize': 'large',
    'axes.labelsize': 'medium',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
    'legend.fontsize': 'x-small',
    'figure.titlesize': 'large',
    # NOTE(review): Times New Roman is a serif face listed under sans-serif;
    # it only takes effect while font.family stays at matplotlib's default
    # ('sans-serif'), which the style reset above restores.
    'font.sans-serif': ['Times New Roman'],
    # Lines and legend appearance
    'lines.linewidth': 1.5,
    'legend.frameon': False,
    # Resolution for on-screen rendering and for saved files
    'figure.dpi': 600,
    'savefig.dpi': 600,
}
plt.rcParams.update(gravity_paper_params)

In [None]:
# Metadata to embed in PNG exports; Title starts empty.
png_metadata = dict(
    Title='',
    Author='Dariush Bahrami, Sadegh Pourianzade',
    Software='Python, matplotlib',
)

In [None]:
# Metadata to embed in SVG exports; Title starts empty and is set per figure.
svg_metadata = dict(
    Title='',
    Contributor=['Dariush Bahrami', 'Sadegh Pourianzade'],
)

In [None]:
# Output root for exported figures, relative to the notebook's working directory.
figs_dir = Path('.') / 'figures'

# $\beta$ Visualization

In [None]:
# Compare each constant rate beta with the step-dependent estimate
# beta_hat = (beta * step + 1) / (step + 2) over 600 minibatches:
#   (a) the rates themselves, (b) each rate r mapped to 1 / (1 - r).
fig, axis = plt.subplots(nrows=2, ncols=1)
fig.set_size_inches(3.5, 6)

steps = np.arange(1, 601)
beta_list = [0.8, 0.9, 0.95]

colors = ['tab:blue', 'tab:green', 'tab:red']

for beta, color in zip(beta_list, colors):
    # At step 1 this equals (beta + 1) / 3; it approaches beta as steps grow.
    beta_hat = (beta*steps+1)/(steps+2)
    averaged_data_beta = 1/(1-beta)
    averaged_data_beta_hat = 1/(1-beta_hat)
    label = r'$\hat{\beta}$ for $\beta$=' + str(beta)

    # Panel (a): dashed line is the constant beta, solid line is beta_hat.
    axis[0].axhline(y=beta, linestyle='--', linewidth=0.75, color=color)
    axis[0].plot(steps, beta_hat, label=label, color=color)

    # Panel (b): the same comparison after the 1 / (1 - r) mapping.
    axis[1].axhline(y=averaged_data_beta, linewidth=0.75, linestyle='--', color=color)
    axis[1].plot(steps, averaged_data_beta_hat, label=label, color=color)

# Axis decorations only need to be applied once, after all curves are drawn.
ax = axis[0]
ax.set_title('(a)')
ax.set_xlabel('Minibatch Number')
ax.set_ylabel(r'$\hat{\beta}$')
# ax.legend()

ax = axis[1]
ax.set_title('(b)')
ax.set_xlabel('Minibatch Number')
ax.set_ylabel(r'Averaged Data')
# A single legend below panel (b) covers both panels (same colours/labels).
ax.legend(bbox_to_anchor=(0.92, -0.25), ncol=3)

fig.suptitle(r'Behavior of $\hat{\beta}$')
plt.tight_layout()


# valid_name = 'beta_hat_behavior'
# name = Path(valid_name + '.svg')
# format_dir = Path('svg')
# fname = Path.joinpath(figs_dir, format_dir, name)
# svg_metadata['Title'] = 'Beta hat Behavior'
# plt.savefig(fname, dpi=600,
#             orientation='portrait',
#             transparent=True,
#             metadata=svg_metadata)
plt.show()

# Results Visualization

In [None]:
def parse_zip_name(zip_path):
    """Split a result archive's file name into its descriptive fields.

    The file name (without extension) is expected to be underscore-separated,
    with the dataset in field 0, the model in field 1 and the optimizer in
    field 3. Field 2 is not used here; its meaning is not visible in this
    notebook.

    Parameters
    ----------
    zip_path : pathlib.Path
        Path to a ``.zip`` results archive.

    Returns
    -------
    dict
        Keys ``'dataset'``, ``'model'``, ``'optimizer'`` (strings) and
        ``'path'`` (``zip_path`` itself).

    Raises
    ------
    ValueError
        If the name has fewer than four underscore-separated fields.
    """
    # Split the stem, not the full name, so that '.zip' can never leak into a
    # field (e.g. when the optimizer is the last field of the name).
    parts = zip_path.stem.split('_')
    if len(parts) < 4:
        raise ValueError(f'Unexpected result archive name: {zip_path.name!r}')
    dataset, model, _, optimizer = parts[:4]
    return {'dataset': dataset,
            'model': model,
            'optimizer': optimizer,
            'path': zip_path}

def get_results(result_path):
    """Parse every ``.zip`` archive directly inside ``result_path``.

    Parameters
    ----------
    result_path : pathlib.Path
        Directory to scan (non-recursive).

    Returns
    -------
    list of dict
        One ``parse_zip_name`` record per archive, ordered by file path.
    """
    # Path.iterdir() yields entries in an arbitrary, filesystem-dependent
    # order; sorting keeps result pairing and figure order reproducible.
    zip_paths = sorted(path for path in result_path.iterdir() if path.suffix == '.zip')
    return [parse_zip_name(path) for path in zip_paths]

def get_dataframe(result_dict: dict):
    """Load the training history stored in a result archive.

    Parameters
    ----------
    result_dict : dict
        A record whose ``'path'`` entry points at a zip archive (as produced
        by ``parse_zip_name``).

    Returns
    -------
    pandas.DataFrame
        The parsed contents of the archive's ``history.csv`` member.
    """
    archive_path = result_dict['path']
    # Both the archive and the member stream are closed once the CSV is parsed.
    with ZipFile(archive_path) as archive, archive.open('history.csv') as history:
        return pd.read_csv(history)
        
def pair_results(result_list_1, result_list_2, result_list_3):
    """Group results from three lists that share a dataset and model.

    Parameters
    ----------
    result_list_1, result_list_2, result_list_3 : list of dict
        Records with at least ``'dataset'`` and ``'model'`` keys.

    Returns
    -------
    list of tuple
        Every ``(first, second, third)`` combination, one record drawn from
        each list in that order, whose ``'dataset'`` and ``'model'`` values
        all agree. The order follows the lists (first list outermost).
    """
    return [
        (first, second, third)
        for first in result_list_1
        for second in result_list_2
        for third in result_list_3
        # Both comparisons are evaluated before combining them.
        if all((first['dataset'] == second['dataset'] == third['dataset'],
                first['model'] == second['model'] == third['model']))
    ]

In [None]:
# Benchmark archives are laid out as results/<optimizer>_benchmarks/<model>/*.zip.
results_dir = Path('results')

# Per-optimizer sub-directories
gravity_dir = Path('gravity_benchmarks')
adam_dir = Path('adam_benchmarks')
rmsprop_dir = Path('rmsprop_benchmarks')

# Per-architecture sub-directories
vgg16_dir = Path('vgg16')
vgg19_dir = Path('vgg19')

gravity_vgg16_results = get_results(results_dir / gravity_dir / vgg16_dir)
gravity_vgg19_results = get_results(results_dir / gravity_dir / vgg19_dir)
adam_vgg16_results = get_results(results_dir / adam_dir / vgg16_dir)
adam_vgg19_results = get_results(results_dir / adam_dir / vgg19_dir)
rmsprop_vgg16_results = get_results(results_dir / rmsprop_dir / vgg16_dir)
rmsprop_vgg19_results = get_results(results_dir / rmsprop_dir / vgg19_dir)

# One flat list per optimizer, VGG16 runs first, then VGG19.
gravity_results = gravity_vgg16_results + gravity_vgg19_results
adam_results = adam_vgg16_results + adam_vgg19_results
rmsprop_results = rmsprop_vgg16_results + rmsprop_vgg19_results

In [None]:
# Match Gravity, Adam and RMSprop runs that share a dataset and model;
# each entry is a (Gravity, Adam, RMSprop) triple of result records.
pairs = pair_results(result_list_1=gravity_results,
                     result_list_2=adam_results,
                     result_list_3=rmsprop_results)

In [None]:
# Load each run's history, keeping the (Gravity, Adam, RMSprop) order of every triple.
pair_dfs = [
    tuple(map(get_dataframe, (gravity, adam, rmsprop)))
    for gravity, adam, rmsprop in pairs
]

In [None]:
# Display names for history.csv columns, optimizers and architectures.
correct_y_label = {'loss': 'Loss', 'accuracy': 'Accuracy',
                   'val_loss': 'Validation Loss',
                   'val_accuracy': 'Validation Accuracy'}
colors = {'Gravity': 'red', 'Adam': 'blue', 'RMSprop': 'green'}
models = {'vgg16': 'VGG16', 'vgg19': 'VGG19'}

# Panel layout shared by every figure: metric shown in each panel, plus its label.
metrics = ['loss', 'accuracy', 'val_loss', 'val_accuracy']
titles = ['(a)', '(b)', '(c)', '(d)']

# One 2x2 figure per matched (Gravity, Adam, RMSprop) triple.
for dfs, infos in zip(pair_dfs, pairs):
    fig, axis = plt.subplots(nrows=2, ncols=2)
    fig.set_size_inches(7.5, 6)

    # axis.flat walks the grid row by row: (0,0), (0,1), (1,0), (1,1).
    for ax, data, panel_title in zip(axis.flat, metrics, titles):
        # Reversed so the first optimizer of the triple (Gravity) is drawn last, on top.
        for df, info in zip(dfs[::-1], infos[::-1]):
            ax.plot(df['epochs'], df[data], label=info['optimizer'], color=colors[info['optimizer']])
        ax.set_title(panel_title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(correct_y_label[data])
        ax.legend()

    # All runs in a triple share dataset and model (guaranteed by pair_results).
    info = infos[0]
    title = f"Dataset: {info['dataset']} - Architecture: {models[info['model']]}"
    fig.suptitle(title)

    plt.tight_layout()
#     valid_name = title.replace('-', '_').replace(':', '_').replace(' ', '').replace('(', '_').replace(')', '_').lower()
#     name = Path(valid_name + '.jpg')
#     format_dir = Path('jpg')
#     fname = Path.joinpath(figs_dir, format_dir, name)
#     plt.savefig(fname, dpi=600,
#                 orientation='landscape',
#                 transparent=True)
    plt.show()
    # Release the figure so memory does not grow with the number of triples.
    plt.close(fig)