In [None]:
import glob
import itertools
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import plotly
import plotly.graph_objs as go
from plotly import tools

plotly.offline.init_notebook_mode(connected=False)

In [None]:
# Compare training / validation loss curves of every Adam run under OUTPUT_DIR.
# Each run directory must contain `args.json` (read for 'optimizer' and 'noise')
# and `history.csv` (per-epoch 'loss' and 'val_loss' columns); others are skipped.
# Outputs: an interactive plotly figure, a static matplotlib figure saved as
# loss.eps, and a per-run summary table saved as loss.csv / loss.tex.
OUTPUT_DIR = 'output'
PLOTLY_YMAX = 0.0015   # upper y-limit of the interactive figure
MPL_XMAX = 70          # last epoch shown in the static figure
MPL_YMAX = 0.0008      # upper y-limit of the static figure
PANEL_TITLES = ['training', 'validation']

dirs = sorted(glob.glob(os.path.join(OUTPUT_DIR, '*')))

fig = tools.make_subplots(1, 2, shared_yaxes=True, print_grid=False, subplot_titles=PANEL_TITLES)
# cycle() so that a 10th run reuses a colour instead of raising StopIteration.
colors = itertools.cycle(['#4363d8', '#f58231', '#fabebe', '#e6beff', '#800000', '#000075', '#a9a9a9', '#ffe119', '#000000'])
records = []  # one summary dict per plotted run

fig2, axes = plt.subplots(1, 2, figsize=(12, 4))

for run_dir in dirs:
    csvpath = os.path.join(run_dir, 'history.csv')
    argspath = os.path.join(run_dir, 'args.json')
    # Skip stray files (e.g. previously written loss.csv) and incomplete runs.
    if not (os.path.exists(csvpath) and os.path.exists(argspath)):
        continue

    with open(argspath, 'r') as fp:
        args = json.load(fp)
    # Filter first so non-Adam runs are never required to define 'noise'.
    if args['optimizer'] != 'adam':
        continue
    noise = args['noise']

    color = next(colors)
    # basename is separator-agnostic; str.replace('output/', '') fails on Windows paths.
    name = os.path.basename(os.path.normpath(run_dir))
    df = pd.read_csv(csvpath, index_col=0)
    df.index += 1  # presumably history.csv stores 0-based epochs; show them 1-based
    label = f'[-{noise}, {noise}]'

    trace_kw = dict(line=dict(color=color), name=f'({noise}) {name}', legendgroup=name, mode='lines')
    fig.add_trace(go.Scatter(x=df.index, y=df.loss, **trace_kw), 1, 1)
    # The shared legend entry comes from the training trace only.
    fig.add_trace(go.Scatter(x=df.index, y=df.val_loss, showlegend=False, **trace_kw), 1, 2)

    df.loss.plot(ax=axes[0], label=label, marker='.')
    df.val_loss.plot(ax=axes[1], label=label, marker='.')

    records.append(dict(name=name, noise=noise, minloss=df.loss.min(), minvalloss=df.val_loss.min()))

fig['layout'].update(height=300)
fig['layout']['yaxis'].update(title='loss (MSE)', hoverformat='.5f', range=[0, PLOTLY_YMAX])
plotly.offline.iplot(fig)

for ax, title in zip(axes.flatten(), PANEL_TITLES):
    ax.set_title(title)
    ax.legend(loc='best')
    ax.set_xlim(-1, MPL_XMAX)
    ax.set_ylim(0, MPL_YMAX)
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss (MSE)')
fig2.tight_layout()
fig2.savefig(os.path.join(OUTPUT_DIR, 'loss.eps'))

# Explicit columns keep the column selection below valid even when no run matched.
results = pd.DataFrame(records, columns=['name', 'noise', 'minloss', 'minvalloss'])
results.to_csv(os.path.join(OUTPUT_DIR, 'loss.csv'))
results[['noise', 'minloss', 'minvalloss']].to_latex(os.path.join(OUTPUT_DIR, 'loss.tex'))
results

In [None]:
# Inspect one run: report the epoch with the smallest validation loss, when
# training stopped, and show the five epochs with the lowest validation loss.
RUN_DIR = os.path.join('output', '0415-200054')  # run to inspect — edit as needed

df = pd.read_csv(os.path.join(RUN_DIR, 'history.csv'), index_col='epoch')
df.index += 1  # presumably stored 0-based; show 1-based epoch numbers

print(f'Smallest validation loss {df.val_loss.min():.5f} at epoch {df.val_loss.idxmin()}.')
print('Training stopped at epoch', df.index.max(), '.')
df.sort_values('val_loss').head(5)