In [14]:
import numpy as np
from utils import GetStats as stats
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import re

%matplotlib inline

In [15]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [37]:
# Run configuration.
#   total_step_valid — validation log entries per epoch (chunk size for chunks())
#   total_step       — training log entries per epoch (chunk size for chunks())
#   epochs           — number of training epochs used for the loss-plot x axis
total_step_valid, total_step, epochs = 3167, 6471, 14

In [17]:
def chunks(logs, n_step):
    """Average consecutive groups of ``n_step`` log entries.

    Every entry is cast with ``float`` before averaging, so numeric strings are
    accepted. When ``len(logs)`` is not a multiple of ``n_step`` the final group
    is shorter than ``n_step`` and is averaged as-is.

    Args:
        logs: sized, sliceable sequence of numbers or numeric strings.
        n_step: number of entries per group (one epoch's worth of steps).

    Returns:
        list of per-group means (as returned by ``np.mean``), in log order.
    """
    return [
        np.mean([float(entry) for entry in logs[start:start + n_step]])
        for start in range(0, len(logs), n_step)
    ]

In [19]:
# The top-of-notebook import binds the class only under the alias `stats`
# (`from utils import GetStats as stats`), so the name `GetStats` is undefined
# and this cell raised NameError on a fresh kernel. Import it under its own name.
from utils import GetStats

# Log reader for the training / validation logs and the BLEU score file.
stats = GetStats(log_train_path='training_log.txt',
                 log_valid_path='validation_log.txt',
                 bleu_path='bleu.txt')

In [20]:
# Raw per-step training/validation loss and perplexity logs, plus the four
# BLEU score series, as read by the GetStats helper.
(losses_train, perplex_train) = stats.get_train_log()
(losses_val, perplex_val) = stats.get_valid_log()
(bleu_1_scores, bleu_2_scores,
 bleu_3_scores, bleu_4_scores) = stats.get_bleu()

In [87]:
# Cast validation metrics to float (the log readers presumably return strings —
# verify against GetStats); float() on an existing float is a no-op, so re-running is safe.
losses_val = list(map(float, losses_val))
perplex_val = list(map(float, perplex_val))

In [88]:
# Cast each BLEU series to float so plotly treats the values numerically.
bleu_1_scores = list(map(float, bleu_1_scores))
bleu_2_scores = list(map(float, bleu_2_scores))
bleu_3_scores = list(map(float, bleu_3_scores))
bleu_4_scores = list(map(float, bleu_4_scores))

In [89]:
# Reduce per-step losses to one mean per epoch (training, then validation).
train_epoch_loss, valid_epoch_loss = (
    chunks(losses_train, total_step),
    chunks(losses_val, total_step_valid),
)

In [90]:
# Reduce per-step perplexity logs to one mean per epoch (training, then validation).
train_epoch_perplex, valid_epoch_perplex = (
    chunks(perplex_train, total_step),
    chunks(perplex_val, total_step_valid),
)

In [94]:
# Train vs. validation loss, one point per epoch, saved to loss.html.
#
# Both traces get an explicit 1-based epoch axis derived from their own length.
# Previously only the training trace had x=1..epochs. The validation trace fell
# back to plotly's default x (0..n-1), which drew it one epoch to the left of
# the matching training point. Using len(...) instead of the hard-coded
# `epochs` also keeps x and y the same length if the config drifts from the
# logs. A plain single-axis figure is used; the previous secondary-y subplot
# spec was never used.
fig = go.Figure()

fig.add_trace(
    go.Scatter(x=np.arange(1, len(train_epoch_loss) + 1), y=train_epoch_loss,
               name="Training", marker_color='#cf3721'))

fig.add_trace(
    go.Scatter(x=np.arange(1, len(valid_epoch_loss) + 1), y=valid_epoch_loss,
               name="Validation", marker_color='#31aeb8'))

fig.update_layout(template='plotly_white',
                  title_text='Train vs Validation loss',
                  autosize=False,
                  width=600,
                  height=500,
                  margin=dict(l=50, r=50, b=100, t=100, pad=4))
fig.update_xaxes(title_text='epoch')
fig.update_yaxes(title_text='loss')

fig.show()
fig.write_html("loss.html")

In [97]:
# Train vs. validation perplexity, one point per epoch, saved to perplex.html.
#
# Neither trace previously had an x, so plotly numbered the epochs from 0. Each
# trace now gets an explicit 1-based epoch axis derived from its own length,
# so the axis matches its 'epoch' label. A plain single-axis figure is used;
# the previous secondary-y subplot spec was never used.
fig = go.Figure()

fig.add_trace(
    go.Scatter(x=np.arange(1, len(train_epoch_perplex) + 1), y=train_epoch_perplex,
               name="Training", marker_color='#cf3721'))

fig.add_trace(
    go.Scatter(x=np.arange(1, len(valid_epoch_perplex) + 1), y=valid_epoch_perplex,
               name="Validation", marker_color='#31aeb8'))

fig.update_layout(template='plotly_white',
                  title_text='Train vs Validation perplexity',
                  autosize=False,
                  width=600,
                  height=500,
                  margin=dict(l=50, r=50, b=100, t=100, pad=4))
fig.update_xaxes(title_text='epoch')
fig.update_yaxes(title_text='perplexity')

fig.show()
fig.write_html("perplex.html")

In [99]:
# 2x2 grid with one panel per BLEU order (1-4), each a thin green line;
# rendered inline and saved to bleu.html.
fig = make_subplots(
    rows=2, cols=2,
    specs=[[{"secondary_y": True} for _ in range(2)] for _ in range(2)],
    subplot_titles=("BLEU-1", "BLEU-2", "BLEU-3", "BLEU-4"),
)

# (series, row, col) in the same order the panels are titled.
for scores, row, col in ((bleu_1_scores, 1, 1),
                         (bleu_2_scores, 1, 2),
                         (bleu_3_scores, 2, 1),
                         (bleu_4_scores, 2, 2)):
    fig.add_trace(go.Scatter(y=scores, marker_color='#258039', line_width=1.2),
                  row=row, col=col, secondary_y=False)

fig.update_layout(template='plotly_white',
                  autosize=False,
                  showlegend=False,
                  width=1000,
                  height=600,
                  margin=dict(l=50, r=50, b=100, t=100, pad=4))

fig.show()
fig.write_html("bleu.html")