In [None]:
import os
from tensorboard.backend.event_processing import event_accumulator
import datetime
import time
import matplotlib.pyplot as plt
import torch
import numpy as np

In [None]:
def get_curve_bleu(basepath, dataset='iwslt14.tokenized.de-en', label=''):
    """Read one run's per-epoch BLEU curve from ``<basepath>/<dataset>/<label>.txt``.

    Each non-blank line is expected to contain ``... = <bleu>, ...`` (for
    example ``BLEU4 = 34.55, 68.4/42.7/...``). The text between the first
    ``=`` and the next ``,`` after it is parsed as a float. If no comma
    follows the ``=``, the rest of the line is used.

    Args:
        basepath: Root directory holding one sub-directory per dataset.
        dataset: Dataset sub-directory name.
        label: Run name; the file read is ``<label>.txt``.

    Returns:
        list[float]: One BLEU value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If the curve file does not exist.
        ValueError: If a non-blank line has no ``=``, or the extracted text
            is not a number.
    """
    path = os.path.join(basepath, dataset, label + '.txt')
    bleus = []
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            start = line.index('=')
            # Look for the comma only *after* '=' so that a comma earlier in
            # the line (e.g. "epoch 2, BLEU4 = 20.1, ...") cannot yield an
            # end index before the start and an empty slice.
            end = line.find(',', start)
            if end == -1:
                # No trailing fields: the value runs to the end of the line.
                end = len(line)
            bleus.append(float(line[start + 1:end].strip()))
    return bleus
        
def plot_bleu(basepath, label_list, dataset, ymin=0, ymax=20, interval=1, dpi=1000):
    """Plot the BLEU curves of several runs on one dataset, save a PNG and show it.

    Args:
        basepath: Root directory passed to ``get_curve_bleu``.
        label_list: Sequence of ``(run_label, linestyle)`` pairs. Each run's
            curve is read via ``get_curve_bleu(basepath, dataset, run_label)``
            and drawn with ``linestyle`` (a matplotlib ``ls`` value such as
            ``'-'`` or ``'--'``).
        dataset: Dataset sub-directory name. It is also used to pick the
            plot title and the output directory ``visualization/<dataset>/``.
        ymin, ymax: Y-axis limits.
        interval: Only changes the x-axis label ("Epoch * <interval>" when
            > 1). The data points are not resampled.
        dpi: Resolution of the saved PNG. The default of 1000 keeps the
            original behaviour.

    Side effects:
        Creates ``visualization/<dataset>/`` if needed and writes
        ``<dataset>_<last run_label>bleu.png`` there. The name embeds the
        label of the *last* entry in ``label_list``, as before.

    Raises:
        ValueError: If ``label_list`` is empty.
    """
    if not label_list:
        raise ValueError('label_list must contain at least one (label, linestyle) pair')

    titles = {'iwslt14.tokenized.fr-en': 'IWSLT14-fr-en',
              'iwslt17.tokenized.de-en': 'IWSLT17-de-en',
              'iwslt14.tokenized.de-en': 'IWSLT14-de-en'}

    fig, ax = plt.subplots(figsize=(6, 4))
    # Fall back to the raw dataset name for datasets without a pretty title.
    ax.set_title('The BLEUs on {}'.format(titles.get(dataset, dataset)))
    if interval > 1:
        ax.set_xlabel('Epoch * {}'.format(str(interval)))
    else:
        ax.set_xlabel('Number of Epoch')
    ax.set_ylabel('BLEU Value')
    ax.set_ylim(ymin, ymax)

    for entry in label_list:
        run_label, linestyle = entry[0], entry[1]
        ax.plot(get_curve_bleu(basepath, dataset, run_label), label=run_label, ls=linestyle)
    ax.legend()

    save_dir = os.path.join('visualization', dataset)
    # makedirs also creates the parent 'visualization/' directory, which
    # os.mkdir would fail on; exist_ok makes re-running the cell safe.
    os.makedirs(save_dir, exist_ok=True)

    last_label = label_list[-1][0]
    fig.savefig(os.path.join(save_dir, dataset + '_' + last_label + 'bleu.png'),
                format='png', dpi=dpi)
    plt.show()
        

    

# BLEU curves: IWSLT14 German-English (de-en)

In [None]:
# Each entry is (run label, matplotlib linestyle); the run label is also
# the name of the curve file <basepath>/<dataset>/<label>.txt.
label_list = [
    ('adam_cyc_nshrink_5e-4', '-'),
    ('adam_cyc_yshrink_5e-4', '-'),
    ('adam_inv_5e-4', '--'),
    ('sgd_cyc_nshrink_6.9', '-'),
    ('adam_inv_3e-4', '--'),
    ('sgd_inv_30', '--'),
]
plot_bleu(
    basepath='./curve_bleu',
    label_list=label_list,
    dataset='iwslt14.tokenized.de-en',
    ymin=0,
    ymax=35,
)

# BLEU curves: IWSLT17 German-English (de-en)

In [None]:
# Each entry is (run label, matplotlib linestyle); the run label is also
# the name of the curve file <basepath>/<dataset>/<label>.txt.
# NOTE(review): 'sgd_cyc_nshrink_8' is drawn dashed, unlike the other
# *_cyc_* runs in this notebook — confirm this is intended.
label_list = [
    ('adam_cyc_nshrink_7.6e-4', '-'),
    ('adam_cyc_yshrink_7.6e-4', '-'),
    ('adam_inv_7.6e-4', '--'),
    ('adam_inv_5e-4', '--'),
    ('adam_inv_3e-4', '--'),
    ('sgd_cyc_nshrink_8', '--'),
    ('sgd_inv_30', '--'),
]
plot_bleu(
    basepath='./curve_bleu',
    label_list=label_list,
    dataset='iwslt17.tokenized.de-en',
    ymin=0,
    ymax=35,
)

# BLEU curves: IWSLT14 French-English (fr-en)

In [None]:
# Each entry is (run label, matplotlib linestyle); the run label is also
# the name of the curve file <basepath>/<dataset>/<label>.txt.
label_list = [
    ('adam_cyc_yshrink_8e-4', '-'),
    ('adam_cyc_nshrink_8e-4', '-'),
    ('adam_inv_8e-4', '--'),
    ('adam_inv_5e-4', '--'),
    ('adam_inv_3e-4', '--'),
    ('adam_inv_1e-5', '--'),
    ('sgd_inv_30', '--'),
]
plot_bleu(
    basepath='./curve_bleu',
    label_list=label_list,
    dataset='iwslt14.tokenized.fr-en',
    ymin=0,
    ymax=40,
)