In [None]:
import os
from tensorboard.backend.event_processing import event_accumulator
import datetime
import time
import matplotlib.pyplot as plt
import torch
import numpy as np

In [None]:
def get_logs_data(basepath, dataset, label, curve_type='valid', interval=1, maxepoch=50):
    """Collect the 'loss' scalar curve of one run from its TensorBoard event files.

    Event files are read from ``<basepath>/<dataset>-<label>/<curve_type>/`` in
    sorted path order. Their 'loss' points are merged, ordered by step, cut to
    ``num * maxepoch`` points and subsampled with stride ``interval``.

    Args:
        basepath: Directory that contains the per-run log directories.
        dataset: Dataset name; first part of the run directory name.
        label: Run label; second part of the run directory name.
        curve_type: Sub-directory of the run directory to read.
        interval: Stride applied when slicing the merged, step-ordered curve.
        maxepoch: Multiplier for the cut-off length. The curve is truncated to
            ``num * maxepoch`` points, where ``num`` is the number of 'loss'
            points in the first event file (presumably the number of points
            logged per epoch -- TODO confirm against the training script).

    Returns:
        list: Loss values (without their steps), ordered by step.

    Raises:
        IndexError: If the log directory contains no files.
        KeyError: If the first event file has no 'loss' tag.
    """
    log_dir = os.path.join(basepath, "%s-%s" % (dataset, label), curve_type)
    files = sorted(os.path.join(log_dir, name) for name in os.listdir(log_dir))
    if not files:
        # Same exception type the bare files[0] lookup used to raise, but with
        # a message that says what is actually wrong.
        raise IndexError("no event files found in %r" % log_dir)

    num = None  # number of 'loss' points in the first event file
    loss = []   # (step, value) pairs gathered from every file
    for file in files:
        ea = event_accumulator.EventAccumulator(file)
        ea.Reload()
        if num is None:
            # First file: a missing 'loss' tag is an error, so the KeyError is
            # allowed to propagate. Each file is parsed only once.
            items = ea.scalars.Items('loss')
            num = len(items)
        else:
            try:
                items = ea.scalars.Items('loss')
            except KeyError:
                # Best effort for a file without a 'loss' tag: repeat the most
                # recent point. Skip when nothing has been collected yet.
                if loss:
                    loss.append(loss[-1])
                continue
        loss.extend((item.step, item.value) for item in items)

    # list.sort is stable, so points sharing a step keep their file order.
    loss.sort(key=lambda point: point[0])
    return [value for _, value in loss[:num * maxepoch:interval]]

# test on iwslt14-de-en

In [None]:
# One entry per plotted line:
#   ([(run label, batch size), ...], legend label, colour, marker, line style)
# Each run label is combined with the dataset name to locate that run's
# TensorBoard log directory. Line style "-" marks the 5e-4 groups and "--"
# the 3e-4 groups.
label_list = [
    (
        [
            ('adam_full-cosine_nshrink_5e-4_4096_with-warmup', 4096),
            ('adam_full-cosine_nshrink_5e-4_1024_perturb0_with-warmup', 1024),
            ('adam_full-cosine_nshrink_5e-4_256_perturb0_with-warmup', 256),
        ],
        "adam_cos_nshrink_5e-4", "#e41a1c", "o", "-",
    ),
    (
        [
            ('adam_full-cosine_nshrink_3e-4_4096_with-warmup', 4096),
            ('adam_full-cosine_nshrink_3e-4_1024_perturb0_with-warmup', 1024),
            ('adam_full-cosine_nshrink_3e-4_256_perturb0_with-warmup', 256),
        ],
        "adam_cos_nshrink_3e-4", "#e41a1c", "o", "--",
    ),
    (
        [
            ('adam_inverse-sqrt_5e-4_4096', 4096),
            ('adam_inverse-sqrt_5e-4_1024', 1024),
            ('adam_inverse-sqrt_5e-4_256', 256),
        ],
        "adam_inv_5e-4", "#377eb8", "X", "-",
    ),
    (
        [
            ('adam_inverse-sqrt_3e-4_4096_with-warmup', 4096),
            ('adam_inverse-sqrt_3e-4_1024', 1024),
            ('adam_inverse-sqrt_3e-4_256', 256),
        ],
        "adam_inv_3e-4", "#80b1d3", "X", "--",
    ),
    (
        [
            ('adam_sine_nshrink_5e-4_4096', 4096),
            ('adam_sine_nshrink_5e-4_1024', 1024),
            ('adam_sine_nshrink_5e-4_256_perturb0', 256),
        ],
        "adam_sine_nshrink_5e-4", "#4daf4a", "*", "-",
    ),
    (
        [
            ('adam_sine_nshrink_3e-4_4096_perturb0', 4096),
            ('adam_sine_nshrink_3e-4_1024_perturb0', 1024),
            ('adam_sine_nshrink_3e-4_256_perturb0', 256),
        ],
        # These are the 3e-4 runs; the legend label previously repeated the
        # 5e-4 label of the group above.
        "adam_sine_nshrink_3e-4", "#4daf4a", "*", "--",
    ),
    (
        [
            ('adam_triangular_nshrink_5e-4_4096', 4096),
            ('adam_triangular_nshrink_5e-4_1024', 1024),
            ('adam_triangular_nshrink_5e-4_256', 256),
        ],
        "adam_tri_nshrink_5e-4", "#984ea3", ">", "-",
    ),
    (
        [
            ('adam_triangular_nshrink_3e-4_4096', 4096),
            ('adam_triangular_nshrink_3e-4_1024', 1024),
            ('adam_triangular_nshrink_3e-4_256', 256),
        ],
        "adam_tri_nshrink_3e-4", "#984ea3", ">", "--",
    ),
]

# Plot, for every group in label_list, the minimum validation loss of each run
# against the batch size that run was trained with.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlabel('Batch size')
ax.set_ylabel('Validation loss')
for models, label, clr, m, ls in label_list:
    _min_loss = []
    _batches = []
    for _mdl, _batch in models:
        loss = np.array(get_logs_data("iwslt14_de-en/tensorboardLog", "iwslt14-de-en", _mdl, "valid", 1, maxepoch=50))
        _min_loss.append(np.min(loss))
        _batches.append(_batch)
    # The colour is passed by keyword rather than as a positional format string.
    ax.plot(_batches, _min_loss, color=clr, ls=ls, label=label, linewidth=2.5, marker=m, fillstyle='none', markersize=10)

# Axis decoration and the high-dpi export run once, after all lines are drawn,
# rather than once per group inside the loop.
ax.set_xticks(sorted({_batch for models, *_ in label_list for _, _batch in models}))
ax.set_yticks(np.arange(3, 13, 1.5))
ax.legend()
fig.savefig("fig6.png", format='png', dpi=1000)