In [None]:
import glob
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import plotly
import plotly.graph_objs as go

plotly.offline.init_notebook_mode(connected=False)

In [None]:
# Plot the training / validation loss of every run directory under OUTPUT_DIR
# whose name starts with RUN_PREFIX. Two figures are produced:
#   * an interactive plotly figure (train = solid line, validation = dotted),
#   * a static 2x2 matplotlib figure that is saved to OUTPUT_DIR/loss.eps.
OUTPUT_DIR = 'output'
RUN_PREFIX = '02'  # only runs whose directory name starts with this are plotted

dirs = sorted(glob.glob(os.path.join(OUTPUT_DIR, '*')))

data = []  # plotly traces; two per plotted run (train, validation)
colors = ['#4363d8', '#f58231', '#fabebe', '#e6beff', '#800000', '#000075', '#a9a9a9', '#ffe119', '#000000']

fig2, axes = plt.subplots(2, 2, figsize=(12, 5))

for run_dir in dirs:

    # Runs without a training history have nothing to plot.
    csvpath = os.path.join(run_dir, 'history.csv')
    if not os.path.exists(csvpath):
        continue

    # Filter on the directory name itself so the check does not depend on the
    # path separator that glob returns on the current platform.
    name = os.path.basename(run_dir)
    if not name.startswith(RUN_PREFIX):
        continue

    # A selected run is expected to ship its args.json; a missing file or a
    # missing 'noise' key raises here rather than being skipped silently.
    with open(os.path.join(run_dir, 'args.json'), 'r') as fp:
        args = json.load(fp)
    noise = args['noise']

    # Each run contributes two traces, so len(data) // 2 is the number of runs
    # plotted so far. Cycle through the palette instead of raising IndexError
    # when there are more runs than colors.
    color = colors[(len(data) // 2) % len(colors)]

    df = pd.read_csv(csvpath, index_col=0)
    df.index += 1  # shift the index by one (presumably to make epochs 1-based)
    data += [
        go.Scatter(x=df.index, y=df.loss,
                   line=dict(color=color),
                   name=name + f' (noise={noise})'),
        go.Scatter(x=df.index, y=df.val_loss,
                   line=dict(color=color, dash='dot'),
                   name=name + f' (val noise={noise})')
    ]

    # Left column: training loss, right column: validation loss. Both rows
    # show the same curves; only the y-range set after the loop differs.
    # Only the top-left panel carries labels because it hosts the legend.
    param = dict(marker='.')
    df.loss.plot(ax=axes[0, 0], label=f'[-{noise}, {noise}]', **param)
    df.loss.plot(ax=axes[1, 0], **param)
    df.val_loss.plot(ax=axes[0, 1], **param)
    df.val_loss.plot(ax=axes[1, 1], **param)

layout = dict(
    height=400,
    xaxis=dict(title='epoch'),
    yaxis=dict(title='loss (MSE)', hoverformat='.5f', range=[0, 0.0015]),
)
fig = go.Figure(data, layout)
plotly.offline.iplot(fig)

axes[0, 0].legend(loc='best')
# Top row uses the wide y-range, bottom row zooms in on small loss values.
axes[0, 0].set_ylim(0, 0.4)
axes[0, 1].set_ylim(0, 0.4)
axes[1, 0].set_ylim(0, 0.0008)
axes[1, 1].set_ylim(0, 0.0008)
for ax in axes.flatten():
    ax.set_ylabel('loss (MSE)')
fig2.tight_layout()
fig2.savefig(os.path.join(OUTPUT_DIR, 'loss.eps'))