In [None]:
import os
from tensorboard.backend.event_processing import event_accumulator
import datetime
import time
import matplotlib.pyplot as plt
import torch
import numpy as np

In [None]:
def get_logs_data(basepath, dataset, label, curve_type='valid', interval=1, maxepoch=50, tag='loss'):
    """Load one scalar curve (by default ``loss``) from TensorBoard event files.

    Event files are read from ``<basepath>/<dataset>-<label>/<curve_type>/``. A few
    short experiment aliases (the keys of ``path_t``) are first mapped to their real
    run-directory suffixes; any other ``label`` is used verbatim.

    Args:
        basepath: Root directory containing the TensorBoard run directories.
        dataset: Dataset prefix of the run directory, e.g. ``'iwslt14-de-en'``.
        label: Run label, or one of the aliases in ``path_t``.
        curve_type: Sub-directory holding the event files, e.g. ``'train'``/``'valid'``.
        interval: Keep every ``interval``-th point of the merged curve.
        maxepoch: Truncate the merged curve to ``num * maxepoch`` points, where ``num``
            is the number of ``tag`` points in the first (sorted) event file.
        tag: Scalar tag to read; defaults to ``'loss'``.

    Returns:
        list: The scalar values, ordered by global step.

    Raises:
        FileNotFoundError: If the event directory is missing or contains no files.
        KeyError: If the first event file has no ``tag`` scalar.
    """
    # Short experiment aliases -> actual run-directory suffixes.
    path_t = {'adam_cyc_nshrink_MLR2': 'adam_cyc_nshrink_1.6e-3',
              'adam_cyc_yshrink_MLR2': 'adam_cyc_yshrink_1.6e-3',
              'adam_cyc_nshrink_MLR1': 'adam_cyc_nshrink_5e-4',
              'adam_cyc_yshrink_MLR1': 'adam_cyc_yshrink_5e-4',
              'adam_cyc_nshrink_5e-4_4096': 'adam_cyc_nshrink_5e-4',
              'sgd_cyc_nshrink_MLR2': 'sgd_cyc_nshrink_23.2',
              'sgd_cyc_yshrink_MLR2': 'sgd_cyc_yshrink_23.2',
              'sgd_cyc_nshrink_MLR1': 'sgd_cyc_nshrink_6.9',
              'sgd_cyc_yshrink_MLR1': 'sgd_cyc_yshrink_6.9'}
    label = path_t.get(label, label)

    event_dir = os.path.join(basepath, dataset + '-' + label, curve_type)
    files = sorted(os.path.join(event_dir, name) for name in os.listdir(event_dir))
    if not files:
        raise FileNotFoundError('no event files found in {}'.format(event_dir))

    points = []  # (step, value) pairs gathered from every event file
    num = None   # number of `tag` points in the first file; fixes the truncation length
    for path in files:
        # Load the log data of this event file.
        ea = event_accumulator.EventAccumulator(path)
        ea.Reload()
        try:
            items = ea.scalars.Items(tag)
        except KeyError:
            if num is None:
                # The first file determines `num`, so it must contain `tag`.
                raise
            # No `tag` scalar in this file: repeat the last collected point so this
            # file still contributes an entry; skip if nothing was collected yet.
            if points:
                points.append(points[-1])
            continue
        if num is None:
            num = len(items)
        points.extend((item.step, item.value) for item in items)

    # Stable sort by step, then truncate to `maxepoch` chunks and sub-sample.
    points.sort(key=lambda point: point[0])
    return [value for _, value in points[:num * maxepoch:interval]]


In [None]:
def plot_log(basepath, dataset, model, label_list, curve_type='train', ymin=0, ymax=20, interval=1, index=0, maxepoch=50):
    """Plot one loss curve per run and save the figure as a PNG.

    The figure is written to ``visualization/<dataset.lower()>/<curve_type>_<index>.png``.
    The output directory, including its parents, is created if needed.

    Args:
        basepath: Root directory of the TensorBoard runs (forwarded to ``get_logs_data``).
        dataset: Dataset name, e.g. ``'iwslt14-de-en'``. Known names get a pretty title;
            others are shown verbatim.
        model: Model name shown in the title.
        label_list: Sequence of ``(label, linestyle)`` pairs; one curve is drawn per pair.
        curve_type: Curve sub-directory forwarded to ``get_logs_data``. ``'train'`` gets an
            iteration x-axis label; anything else gets an epoch label.
        ymin, ymax: y-axis limits.
        interval: Sub-sampling step forwarded to ``get_logs_data``.
        index: Suffix of the output file name; any value is accepted (formatted with ``str``).
        maxepoch: Number of epochs to plot, forwarded to ``get_logs_data``.
    """
    titles = {'iwslt14-fr-en': 'IWSLT14-fr-en',
              'iwslt17-de-en': 'IWSLT17-de-en',
              'iwslt14-de-en': 'IWSLT14-de-en'}
    curve_name = curve_type.capitalize()

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set_title('{} Loss for {} on {}'.format(curve_name, model, titles.get(dataset, dataset)))
    if curve_type == 'train':
        ax.set_xlabel('Iteration * {}'.format(interval))
    else:
        ax.set_xlabel('Number of Epoch')
    # Values are raw losses, not percentages.
    ax.set_ylabel('{} Loss'.format(curve_name))
    ax.set_ylim(ymin, ymax)

    for label_tuple in label_list:
        label, linestyle = label_tuple[0], label_tuple[1]
        loss = np.array(get_logs_data(basepath, dataset, label, curve_type, interval, maxepoch=maxepoch))
        ax.plot(loss, ls=linestyle, label=label)
    ax.legend()

    # makedirs also creates the parent 'visualization' directory if it is missing.
    save_dir = os.path.join('visualization', dataset.lower())
    os.makedirs(save_dir, exist_ok=True)
    # format() tolerates non-string indices such as the default 0.
    fig.savefig(os.path.join(save_dir, '{}_{}.png'.format(curve_type, index)), format='png', dpi=1000)
    plt.show()

    

# Test on IWSLT14-de-en

In [None]:
# Figure 3 (IWSLT14 de-en, validation loss): Adam `cyc` runs (solid) vs. Adam `inv` runs (dashed).
label_list = [
    ('adam_cyc_nshrink_5e-4', '-'),
    ('adam_cyc_yshrink_5e-4', '-'),
    ('adam_inv_1e-3', '--'),
    ('adam_inv_5e-4', '--'),
    ('adam_inv_3e-4', '--'),
    ('adam_inv_1e-5', '--'),
]

plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=2,
    ymax=14,
    index='Figure-3',
)

In [None]:
# Figure 4 (IWSLT14 de-en, validation loss): SGD `cyc` runs (solid) vs. SGD `inv` runs (dashed).
label_list = [
    ('sgd_cyc_nshrink_6.9', '-'),
    ('sgd_cyc_yshrink_6.9', '-'),
    ('sgd_inv_0.1', '--'),
    ('sgd_inv_1', '--'),
    ('sgd_inv_10', '--'),
    ('sgd_inv_20', '--'),
    ('sgd_inv_30', '--'),
]

plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=2,
    ymax=18,
    index='Figure-4',
)


In [None]:
# Figure 5 (IWSLT14 de-en, validation loss): SGD `cyc` runs (solid) vs. Adam `cyc` runs (dashed).
label_list = [
    ('sgd_cyc_nshrink_6.9', '-'),
    ('sgd_cyc_yshrink_6.9', '-'),
    ('adam_cyc_nshrink_5e-4', '--'),
    ('adam_cyc_yshrink_5e-4', '--'),
]
plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=2,
    ymax=12,
    index='Figure-5',
)

In [None]:
# Figure 6 (IWSLT14 de-en, validation loss): `adam_cyc_nshrink_5e-4` with suffixes 4096/1024/256
# (presumably batch sizes - confirm). The `_4096` label is an alias for the plain
# `adam_cyc_nshrink_5e-4` run directory (see `path_t` in `get_logs_data`).
label_list = [
    ('adam_cyc_nshrink_5e-4_4096', '-'),
    ('adam_cyc_nshrink_5e-4_1024', '-'),
    ('adam_cyc_nshrink_5e-4_256', '-'),
]

plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=3,
    ymax=25,
    index='Figure-6',
)


In [None]:
# Figure 8 (IWSLT14 de-en, validation loss): SGD `cyc` runs for the MLR2 (solid) and MLR1 (dashed)
# aliases, which `get_logs_data` maps to concrete run directories.
label_list = [
    ('sgd_cyc_nshrink_MLR2', '-'),
    ('sgd_cyc_yshrink_MLR2', '-'),
    ('sgd_cyc_nshrink_MLR1', '--'),
    ('sgd_cyc_yshrink_MLR1', '--'),
]
plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=3,
    ymax=10,
    index='Figure-8',
)


In [None]:
# Figure 9 (IWSLT14 de-en, validation loss): Adam `cyc` runs for the MLR2 (solid) and MLR1 (dashed)
# aliases, which `get_logs_data` maps to concrete run directories.
label_list = [
    ('adam_cyc_nshrink_MLR2', '-'),
    ('adam_cyc_yshrink_MLR2', '-'),
    ('adam_cyc_nshrink_MLR1', '--'),
    ('adam_cyc_yshrink_MLR1', '--'),
]
plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=2,
    ymax=15,
    index='Figure-9',
)

# Test on IWSLT17-de-en

In [None]:
# Figure 11 (IWSLT17 de-en, validation loss): `cyc` runs (solid) vs. `inv` runs (dashed),
# for both Adam and SGD.
label_list = [
    ('adam_cyc_nshrink_7.6e-4', '-'),
    ('adam_cyc_yshrink_7.6e-4', '-'),
    ('adam_inv_7.6e-4', '--'),
    ('adam_inv_5e-4', '--'),
    ('adam_inv_3e-4', '--'),
    ('sgd_cyc_nshrink_8', '-'),
    ('sgd_inv_30', '--'),
]
plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt17-de-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=1,
    ymax=12,
    index='Figure-11',
)

# Test on IWSLT14-fr-en

In [None]:
# Figure 12 (IWSLT14 fr-en, validation loss): Adam `cyc` runs (solid) vs. `inv` runs (dashed).
# NOTE(review): the original list contained ('adam_inv_8e-4', '--') twice. That drew the same
# curve twice and duplicated its legend entry, so the repeat was dropped. Confirm whether a
# different run (e.g. another learning rate) was intended in its place.
label_list = [
    ('adam_cyc_nshrink_8e-4', '-'),
    ('adam_cyc_yshrink_8e-4', '-'),
    ('adam_inv_1e-3', '--'),
    ('adam_inv_8e-4', '--'),
    ('adam_inv_5e-4', '--'),
    ('adam_inv_1e-5', '--'),
    ('sgd_inv_30', '--'),
]
plot_log(
    basepath='./tensorboardLog',
    dataset='iwslt14-fr-en',
    model='transformer',
    label_list=label_list,
    curve_type='valid',
    ymin=2,
    ymax=20,
    index='Figure-12',
)
