In [None]:
import os

from plotly import tools
from plotly.offline import init_notebook_mode, plot, iplot
import plotly.graph_objs as go


# Enable offline plotly rendering inside this notebook. Per the plotly docs,
# connected=True loads plotly.js from a CDN instead of embedding it in the
# notebook, so figures need network access to render.
init_notebook_mode(connected=True)

In [None]:
def parce_log(log, key_names=('Train', 'Validation')):
    """Parse training-log lines into per-type lists of metric dicts.

    Each line is whitespace-split. Token 3 selects the record type. From
    token 5 on, tokens alternate ``name:`` / ``value[,]`` pairs. Token 4 is
    skipped. A line shaped like
    ``[2019-01-01 12:00:00] - Train - Epoch: 1, Loss: 0.5`` matches these
    positions. The exact prefix is an assumption; only the token positions
    are fixed by this code.

    Args:
        log: iterable of log lines (e.g. ``file.readlines()``).
        key_names: record types to collect. Lines whose token 3 is not one of
            them, and lines with fewer than four tokens, are ignored.

    Returns:
        dict mapping every key name to a list of ``{metric_name: number}``
        dicts, in log order.

    Raises:
        ValueError: if a value token of a collected line is not numeric.
    """
    def fix_number(str_num):
        # Every value except the last one on a line is comma-terminated.
        if str_num.endswith(','):
            str_num = str_num[:-1]
        # Try int first so integer counters (e.g. epoch) stay ints. float()
        # also accepts exponents ('1e-05'), 'nan' and 'inf', which a plain
        # "contains '.'" check would send to int() and crash.
        try:
            return int(str_num)
        except ValueError:
            return float(str_num)

    def fix_name(str_name):
        # Metric names are written as "Name:"; drop the trailing colon.
        return str_name[:-1] if str_name.endswith(':') else str_name

    parsed_log = {key: [] for key in key_names}
    for line in log:
        tokens = line.split()
        # Blank or short lines (trailing newline, stray messages) have no
        # type token at index 3; skip them instead of raising IndexError.
        if len(tokens) < 4:
            continue
        line_type = tokens[3]
        if line_type in parsed_log:
            keys = [fix_name(tokens[i]) for i in range(5, len(tokens), 2)]
            values = [fix_number(tokens[i]) for i in range(6, len(tokens), 2)]
            parsed_log[line_type].append(dict(zip(keys, values)))
    return parsed_log

def plot_data(log_list: list, log_id=None):
    """Convert parsed log records into plotly line traces.

    The first key of the first record is used as the x-axis. That key is
    presumably 'Epoch' when lines start with "Epoch: ..."; verify against the
    log format. Every other key becomes one ``lines+markers`` trace plotted
    against it.

    Args:
        log_list: list of dicts for one record type, as produced by
            ``parce_log``. Every record is assumed to contain the keys of the
            first record; a missing key raises ``KeyError``.
        log_id: optional id appended to every trace name as ``' <log_id>'``
            (e.g. the fold index).

    Returns:
        list of ``go.Scatter`` traces, or ``[]`` when ``log_list`` is empty.
    """
    # A run may not have logged any line of this type yet (e.g. no
    # Validation line after the first epoch); there is nothing to plot.
    if not log_list:
        return []

    log_names = list(log_list[0].keys())
    # One column of values per metric, in record order.
    columns = [[sample[name] for sample in log_list] for name in log_names]
    x_values = columns[0]
    suffix = '' if log_id is None else ' ' + str(log_id)

    return [go.Scatter(x=x_values,
                       y=column,
                       mode='lines+markers',
                       name=name + suffix)
            for name, column in zip(log_names[1:], columns[1:])]

def get_plot_data(log_path, key_names, log_id=None):
    """Read one log file and return the plotly traces for all requested types.

    Traces are concatenated in ``key_names`` order. ``log_id`` is forwarded
    to ``plot_data`` as the suffix for trace names.
    """
    with open(log_path, 'r') as log_file:
        lines = log_file.readlines()

    parsed = parce_log(lines, key_names)
    traces = []
    for key in key_names:
        traces.extend(plot_data(parsed[key], log_id))
    return traces

def draw_log(log_path, key_names, log_id=None):
    """Show every curve of a single log file in one interactive plotly figure."""
    iplot(get_plot_data(log_path, key_names, log_id))

def draw_folds(log_paths, key_names):
    """Draw a separate interactive figure for each fold's log.

    All logs are read and parsed first, then one figure per fold is shown.
    Each figure is titled ``Fold <index>``, where the index is the path's
    position in ``log_paths``.
    """
    paths = list(log_paths)
    fold_traces = [get_plot_data(path, key_names, fold_idx)
                   for fold_idx, path in enumerate(paths)]

    for fold_idx, traces in enumerate(fold_traces):
        fold_layout = go.Layout(title='Fold ' + str(fold_idx))
        iplot(go.Figure(data=traces, layout=fold_layout))

def compare_folds(log_paths, key_names):
    """Overlay the curves of all folds in one figure of stacked subplots.

    Trace ``j`` of every fold (the j-th curve returned by ``get_plot_data``)
    is drawn in subplot row ``j + 1``. The same curve of all folds therefore
    shares one panel. All rows share the x-axis.

    Args:
        log_paths: iterable of log-file paths, one per fold. The position in
            this sequence is used as the fold id in trace names.
        key_names: record types to plot (see ``parce_log``).

    Returns:
        None. Nothing is drawn when no fold yields any trace (e.g. an empty
        ``log_paths``).
    """
    plots = [get_plot_data(log_path, key_names, i) for (i, log_path) in enumerate(log_paths)]

    # Size the grid from the data. A fixed-size `specs` list breaks
    # make_subplots for any other trace count. Folds may also differ in trace
    # count (e.g. a fold with no Validation lines yet), so use the largest.
    n_rows = max((len(fold_traces) for fold_traces in plots), default=0)
    if n_rows == 0:
        return

    # plotly rejects vertical_spacing > 1 / (rows - 1). Keep 0.1 where it
    # fits (up to 6 rows) and shrink it for taller grids so panels keep height.
    spacing = min(0.1, 0.5 / (n_rows - 1)) if n_rows > 1 else 0.1
    fig = tools.make_subplots(rows=n_rows, cols=1, specs=[[{}] for _ in range(n_rows)],
                              shared_xaxes=True, shared_yaxes=False,
                              vertical_spacing=spacing)

    for fold_traces in plots:
        for row, trace in enumerate(fold_traces, start=1):
            fig.append_trace(trace, row, 1)
    # Figure height follows the number of stacked panels, not the number of folds.
    fig['layout'].update(height=300 * n_rows, width=800, showlegend=False)
    iplot(fig)

In [None]:
# Experiment directory: each sub-directory is expected to be one fold that
# contains a `log_file`.
exp_path = '../data/experiments/densenet121-folds-001/'
log_file = 'log.txt'
key_names=['Train', 'Validation']

# os.listdir() returns entries in arbitrary order. Sort so fold indices
# (used as "Fold <i>" titles and trace suffixes) are reproducible across
# runs and machines.
folds_paths = [os.path.join(exp_path, fold_name, log_file)
               for fold_name in sorted(os.listdir(exp_path))
               if os.path.isdir(os.path.join(exp_path, fold_name))]

In [None]:
# Quick look at a single run: all key_names curves for the first entry of
# folds_paths (which fold that is depends on the order of folds_paths).
draw_log(folds_paths[0], key_names)

In [None]:
# One figure per fold, titled "Fold <i>" by its position in folds_paths.
draw_folds(folds_paths, key_names)

In [None]:
# All folds overlaid in shared-x subplots: row j+1 holds trace j of every fold.
compare_folds(folds_paths, key_names)