In [None]:
from pathlib import Path
from zipfile import ZipFile
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams

In [None]:
def parse_zip_name(zip_path):
    """Extract run metadata from a result archive's file name.

    The name (without its ``.zip`` suffix) is split on ``_``: field 0 is the
    dataset, field 1 the model and field 3 the optimizer; field 2 and any
    fields after 3 are ignored.

    Parameters
    ----------
    zip_path : pathlib.Path
        Path to a result archive.

    Returns
    -------
    dict
        ``{'dataset', 'model', 'optimizer', 'path'}``; ``path`` is ``zip_path``.

    Raises
    ------
    ValueError
        If the name has fewer than four ``_``-separated fields.
    """
    # Use the stem so the optimizer never picks up '.zip' when it is the last field.
    parts = zip_path.stem.split('_')
    if len(parts) < 4:
        raise ValueError(
            f"cannot parse result archive name {zip_path.name!r}: "
            "expected at least 4 '_'-separated fields"
        )
    dataset, model, _, optimizer = parts[:4]
    return {'dataset': dataset,
            'model': model,
            'optimizer': optimizer,
            'path': zip_path}

def get_results(result_path):
    """Parse every ``.zip`` result archive directly inside ``result_path``.

    Archives are processed in sorted path order. ``Path.iterdir`` yields
    entries in arbitrary order, which would otherwise make the pairing order
    (and therefore figure order) vary between machines and runs.

    Parameters
    ----------
    result_path : pathlib.Path
        Directory containing result archives.

    Returns
    -------
    list of dict
        One ``parse_zip_name`` record per archive.
    """
    archives = sorted(
        entry for entry in result_path.iterdir()
        # Skip directories that merely end in '.zip'; only files can be opened as archives.
        if entry.suffix == '.zip' and entry.is_file()
    )
    return [parse_zip_name(archive) for archive in archives]

def get_dataframe(result_dict: dict):
    """Read the ``history.csv`` member of a result archive into a DataFrame.

    ``result_dict['path']`` is handed to ``ZipFile`` as-is, so it may be a
    path or an already-open binary file object.
    """
    archive_path = result_dict['path']
    with ZipFile(archive_path) as archive, archive.open('history.csv') as history:
        frame = pd.read_csv(history)
    return frame
        
def pair_results(result_list_1, result_list_2, result_list_3):
    """Match runs from three result lists that share a dataset and a model.

    Each input is a list of dicts with at least ``'dataset'`` and ``'model'``
    keys (as produced by ``parse_zip_name``). Every combination
    ``(r1, r2, r3)`` whose dataset and model agree is returned. The order
    follows ``result_list_1``, then ``result_list_2``, then
    ``result_list_3`` — the same order a triple nested loop would produce.

    Lists 2 and 3 are indexed by key once, so the cost is linear in the
    inputs plus the number of matches, instead of n1*n2*n3.

    Returns
    -------
    list of tuple
        ``[(r1, r2, r3), ...]``; empty if nothing matches.
    """
    def _key(result):
        # The fields a triple must agree on.
        return result['dataset'], result['model']

    def _index(results):
        # Bucket results by key while preserving their original relative order.
        buckets = {}
        for result in results:
            buckets.setdefault(_key(result), []).append(result)
        return buckets

    by_key_2 = _index(result_list_2)
    by_key_3 = _index(result_list_3)

    pairs = []
    for first in result_list_1:
        key = _key(first)
        for second in by_key_2.get(key, ()):
            for third in by_key_3.get(key, ()):
                pairs.append((first, second, third))
    return pairs

In [None]:
# Result archives live in results/<optimizer>_benchmarks/<vgg16|vgg19>/.
results_dir = Path('results')
gravity_dir = Path('gravity_benchmarks')
adam_dir = Path('adam_benchmarks')
rmsprop_dir = Path('rmsprop_benchmarks')
vgg16_dir = Path('vgg16')
vgg19_dir = Path('vgg19')

# Parsed metadata for every archive, per optimizer and architecture.
gravity_vgg16_results = get_results(results_dir / gravity_dir / vgg16_dir)
gravity_vgg19_results = get_results(results_dir / gravity_dir / vgg19_dir)
adam_vgg16_results = get_results(results_dir / adam_dir / vgg16_dir)
adam_vgg19_results = get_results(results_dir / adam_dir / vgg19_dir)
rmsprop_vgg16_results = get_results(results_dir / rmsprop_dir / vgg16_dir)
rmsprop_vgg19_results = get_results(results_dir / rmsprop_dir / vgg19_dir)

# Per-optimizer lists covering both architectures (VGG16 entries first).
gravity_results = gravity_vgg16_results + gravity_vgg19_results
adam_results = adam_vgg16_results + adam_vgg19_results
rmsprop_results = rmsprop_vgg16_results + rmsprop_vgg19_results

In [None]:
# Group runs into (gravity, adam, rmsprop) triples that share a dataset and model.
pairs = pair_results(gravity_results, adam_results, rmsprop_results)
# Fail loudly on a fresh kernel (e.g. wrong working directory or missing archives)
# instead of silently producing no figures further down.
assert pairs, 'no matching gravity/adam/rmsprop result triples found under results/'

In [None]:
# Load history.csv for every run of each matched triple, keeping the triple's order.
pair_dfs = [tuple(get_dataframe(result) for result in triple) for triple in pairs]

In [None]:
# Display the first matched (gravity, adam, rmsprop) triple as a quick sanity check of the pairing.
pairs[0]

In [None]:
# Figure metadata; presumably meant for savefig(metadata=...) on PNG output
# (not used by the visible cells — TODO confirm).
png_metadata = dict(
    Title='',
    Author='Dariush Bahrami, Sadegh Pourianzade',
    Software='Python, matplotlib',
)

In [None]:
# Figure metadata; presumably meant for savefig(metadata=...) on SVG output
# (not used by the visible cells — TODO confirm).
svg_metadata = dict(
    Title='',
    Contributor=['Dariush Bahrami', 'Sadegh Pourianzade'],
)

In [None]:
# Root output directory for exported figures, relative to the working directory.
figs_dir = Path().joinpath('figures')

In [None]:
# One metric per subplot, filled row-major into a 2x2 grid.
PLOT_METRICS = ('loss', 'accuracy', 'val_loss', 'val_accuracy')
PANEL_LABELS = ('(a)', '(b)', '(c)', '(d)')
FIG_SIZE_INCHES = (7.5, 6)
FIG_DPI = 600
# Set to True to also write each figure as JPEG under figs_dir / 'jpg' (created if missing).
SAVE_FIGURES = False

plt.style.use(['default'])
# Times New Roman is a serif face; it is listed under 'sans-serif' because that is
# the family the default style selects, so it becomes the font actually used.
rcParams['font.sans-serif'] = ['Times New Roman']
rcParams['xtick.labelsize'] = 'small'
rcParams['ytick.labelsize'] = 'small'
rcParams['legend.frameon'] = False
rcParams['lines.linewidth'] = 1.5

correct_y_label = {'loss': 'Loss', 'accuracy': 'Accuracy',
                   'val_loss': 'Validation Loss',
                   'val_accuracy': 'Validation Accuracy'}
colors = {'Gravity': 'red', 'Adam': 'blue', 'RMSprop': 'green'}
models = {'vgg16': 'VGG16', 'vgg19': 'VGG19'}


def _plot_metric_panel(ax, metric, dfs, infos, panel_label):
    """Draw one metric column for every run of a triple onto ``ax``."""
    # Reverse the triple so the first run (gravity) is drawn last, i.e. on top.
    for df, info in zip(reversed(dfs), reversed(infos)):
        ax.plot(df['epochs'], df[metric],
                label=info['optimizer'], color=colors[info['optimizer']])
    # Decorations are set once, after every curve exists, so the legend lists all runs.
    ax.set_title(panel_label)
    ax.set_xlabel('Epochs', fontsize='medium')
    ax.set_ylabel(correct_y_label[metric], fontsize='medium')
    ax.legend(fontsize='x-small')


def _figure_path(title, fmt='jpg'):
    """Turn a figure title into a filesystem-safe path under ``figs_dir / fmt``."""
    stem = (title.replace('-', '_').replace(':', '_').replace(' ', '')
            .replace('(', '_').replace(')', '_').lower())
    return figs_dir / fmt / f'{stem}.{fmt}'


for dfs, infos in zip(pair_dfs, pairs):
    fig, axes = plt.subplots(nrows=2, ncols=2, dpi=FIG_DPI)
    fig.set_size_inches(*FIG_SIZE_INCHES)

    for ax, metric, panel_label in zip(axes.flat, PLOT_METRICS, PANEL_LABELS):
        _plot_metric_panel(ax, metric, dfs, infos, panel_label)

    # All runs in a triple share dataset and model by construction of the pairing,
    # so read them from the first entry instead of a variable leaked from a loop.
    first = infos[0]
    title = f"Dataset: {first['dataset']} - Architecture: {models[first['model']]}"
    fig.suptitle(title, fontsize='large')
    fig.tight_layout()

    if SAVE_FIGURES:
        out_path = _figure_path(title)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # orientation='landscape' is not passed: matplotlib only honours it for PostScript.
        fig.savefig(out_path, dpi=FIG_DPI, transparent=True)

    plt.show()
    # Release the figure explicitly; harmless if the backend already closed it on show().
    plt.close(fig)