In [73]:
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pathlib


In [82]:
CSV_DIR = 'logs/csv'
METRICS_NAME = ['PSNR', 'SSIM']
METRICS = ['PSNR', 'SSIM']
TRAIN_VAL = 'train'  # train
#TRAIN_VAL = 'both'
TRAIN_VAL_LONG = {'train': 'Training set', 
                  'val': 'Validation set'}
TRAIN_VAL_LONG = TRAIN_VAL_LONG[TRAIN_VAL]
SMOOTH = True
SMOOTH_PAR = 0.9

In [159]:
def list_exp_csv(csv_dir, exp_name, pre_gan, train_val, metrics):
    if isinstance(csv_dir, str):
        csv_dir = pathlib.Path(csv_dir)
    #csv_list = list(csv_dir.glob('**/' + exp_name + '/*' + pre_gan + '*' + train_val + '*' + metric + '*.csv'))
    csv_list = []
    for metric in metrics:
        csv_list.append(list(csv_dir.glob('**/' + exp_name + '/*' + pre_gan + '*' + metric + '*.csv')))
    csv_list = [item for sublist in csv_list for item in sublist]
    return csv_list

def csv_list_to_dfs(csv_list):
    dfs = {}
    for csv_file in csv_list:
        dfs[csv_file.stem] = pd.read_csv(csv_file)
    return dfs

def merge_dfs_by_epoch(dfs, metric, wide_long, smooth=True, smooth_par=0.9):
    new_df = pd.DataFrame()
    smooth_df = pd.DataFrame()
    exp_variations = []
    for name, df in dfs.items():
        if metric not in name:
            continue
        try:
            new_df['epoch']
        except KeyError:
            new_df['epoch'] = df.index + 1
            smooth_df['epoch'] = df.index + 1
        idx1 = name.find('_e') + 1
        idx2 = name.find('_2')
        exp_variation = name[idx1:idx2]
        exp_variations.append(exp_variation)
        col_name = name
        new_df[col_name] = df['Value']
        if smooth:
            smooth_df[col_name] = df['Value'].ewm(alpha=(1 - smooth_par)).mean()            
    if wide_long == 'long':
        new_df = pd.melt(new_df, id_vars=['epoch'], value_name=metric, var_name='exp_variation_id')
        if smooth:
            smooth_df = pd.melt(smooth_df, id_vars='epoch', value_name=metric, var_name='exp_variation_id')
            new_df[metric + '-smooth'] = smooth_df[metric]
            
        new_df.loc[new_df['exp_variation_id'].str.contains('train'), 'train_val'] = 'train'
        new_df.loc[new_df['exp_variation_id'].str.contains('val'), 'train_val'] = 'val'

        new_df.loc[new_df['exp_variation_id'].str.contains('WV02'), 'sensor'] = 'WV02'
        new_df.loc[new_df['exp_variation_id'].str.contains('GE01'), 'sensor'] = 'GE01'
        new_df['sensor'].fillna('WV02', inplace=True)

        for exp_variation in exp_variations:
            new_df.loc[new_df['exp_variation_id'].str.contains(exp_variation), 'exp_variation'] = exp_variation
        # new_df['legend'] = new_df['exp_variation'] + '-' + new_df['sensor']
        #new_df.set_index(['exp_variation', 'train_val', 'sensor'], inplace=True)
        new_df['metric'] = metric
    return new_df

csv_list = list_exp_csv(CSV_DIR, 'e01', 
                        pre_gan='pre', 
                        train_val=TRAIN_VAL, 
                        metrics=METRICS_NAME)
csv_list
dfs = []
for metric in METRICS:
    dfs.append(merge_dfs_by_epoch(csv_list_to_dfs(csv_list), 
                                   metric=metric, 
                                   wide_long='long', 
                                   smooth=SMOOTH, 
                                   smooth_par=SMOOTH_PAR))
metric_df = pd.concat(dfs, axis=1)
metric_df = dfs[0]
metric_df = metric_df.merge(dfs[1], how='outer')
metric_df

Unnamed: 0,epoch,exp_variation_id,PSNR,PSNR-smooth,train_val,sensor,exp_variation,metric,SSIM,SSIM-smooth
0,1,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,16.337782,16.337782,train,WV02,e01-3-pre,PSNR,,
1,2,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,25.686089,21.257943,train,WV02,e01-3-pre,PSNR,,
2,3,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,28.110910,23.786713,train,WV02,e01-3-pre,PSNR,,
3,4,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,29.242495,25.373158,train,WV02,e01-3-pre,PSNR,,
4,5,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,30.221743,26.557154,train,WV02,e01-3-pre,PSNR,,
...,...,...,...,...,...,...,...,...,...,...
7995,396,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,,,val,WV02,e01-8-pre,SSIM,0.847548,0.814846
7996,397,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,,,val,WV02,e01-8-pre,SSIM,0.813490,0.814711
7997,398,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,,,val,WV02,e01-8-pre,SSIM,0.800377,0.813277
7998,399,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,,,val,WV02,e01-8-pre,SSIM,0.813178,0.813267


In [214]:
fig = px.line(metric_df, 
              x='epoch', 
              y=[METRIC[0] + '-smooth', METRIC[1] + '-smooth'], 
              color='exp_variation', 
              #range_y=(29,44), 
              title='Experiment 01 - Pretraining', 
              facet_col='train_val',
              facet_row='metric',
              line_dash='sensor', 
              line_dash_sequence=['solid', 'dash']
             )
fig.update_layout(legend_title_text='Experiment variation')
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_yaxes(matches=None)
#fig.layout.yaxis1.matches = None
#fig.layout.yaxis2.matches = None
#fig.layout.yaxis3.matches = None
#fig.layout.yaxis4.matches = None
fig.layout.yaxis1.range = [0.7, 1]
fig.layout.yaxis2.range = [0.7, 1]
fig.layout.yaxis3.range = [30, 44]
fig.layout.yaxis4.range = [30, 44]

fig.show()

In [74]:
main_fig = make_subplots(rows=2, cols=1)

main_fig.

main_fig.add_trace(
    go.Scatter(x=[20, 30, 40], y=[50, 60, 70]),
    row=1, col=2
)

main_fig.update_layout(height=600, width=800, title_text="Side By Side Subplots")
main_fig.show()

ValueError: 
    Invalid element(s) received for the 'data' property of 
        Invalid elements include: [Figure({
    'data': [{'hovertemplate': ('exp_variation=e01-3-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-3-pre, WV02',
              'line': {'color': '#636efa', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-3-pre, WV02',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x',
              'y': array([16.33778191, 21.2579433 , 23.78671346, ..., 39.58029193, 39.58937013,
                          39.60852341]),
              'yaxis': 'y'},
             {'hovertemplate': ('exp_variation=e01-3-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-3-pre, WV02',
              'line': {'color': '#636efa', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-3-pre, WV02',
              'showlegend': False,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([25.92531776, 27.62703614, 28.30950482, ..., 33.70764297, 33.72260598,
                          33.74346443]),
              'yaxis': 'y2'},
             {'hovertemplate': ('exp_variation=e01-3-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-3-pre, GE01',
              'line': {'color': '#636efa', 'dash': 'dash'},
              'mode': 'lines',
              'name': 'e01-3-pre, GE01',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([19.51236725, 21.71211845, 22.69597029, ..., 31.64888061, 31.72692963,
                          31.71175296]),
              'yaxis': 'y2'},
             {'hovertemplate': ('exp_variation=e01-4-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-4-pre, WV02',
              'line': {'color': '#EF553B', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-4-pre, WV02',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x',
              'y': array([13.67051315, 18.83787095, 21.73972687, ..., 40.70925931, 40.71650454,
                          40.71411602]),
              'yaxis': 'y'},
             {'hovertemplate': ('exp_variation=e01-4-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-4-pre, WV02',
              'line': {'color': '#EF553B', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-4-pre, WV02',
              'showlegend': False,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([22.51325417, 25.47343917, 25.31172312, ..., 34.80031745, 34.7822608 ,
                          34.7658538 ]),
              'yaxis': 'y2'},
             {'hovertemplate': ('exp_variation=e01-4-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-4-pre, GE01',
              'line': {'color': '#EF553B', 'dash': 'dash'},
              'mode': 'lines',
              'name': 'e01-4-pre, GE01',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([18.23464012, 21.17246698, 21.80885817, ..., 32.97408878, 32.97277022,
                          32.97975993]),
              'yaxis': 'y2'},
             {'hovertemplate': ('exp_variation=e01-6-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-6-pre, WV02',
              'line': {'color': '#00cc96', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-6-pre, WV02',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x',
              'y': array([15.47518539, 20.71538283, 23.63822032, ..., 42.56808451, 42.57274762,
                          42.58219267]),
              'yaxis': 'y'},
             {'hovertemplate': ('exp_variation=e01-6-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-6-pre, WV02',
              'line': {'color': '#00cc96', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-6-pre, WV02',
              'showlegend': False,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([25.20268059, 27.39882991, 28.34068871, ..., 35.09591809, 35.09441535,
                          35.13319883]),
              'yaxis': 'y2'},
             {'hovertemplate': ('exp_variation=e01-8-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-8-pre, WV02',
              'line': {'color': '#ab63fa', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-8-pre, WV02',
              'showlegend': True,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x',
              'y': array([17.59927177, 22.80488898, 25.57485775, ..., 42.80959141, 42.81051666,
                          42.80756748]),
              'yaxis': 'y'},
             {'hovertemplate': ('exp_variation=e01-8-pre<br>sen' ... '}<br>value=%{y}<extra></extra>'),
              'legendgroup': 'e01-8-pre, WV02',
              'line': {'color': '#ab63fa', 'dash': 'solid'},
              'mode': 'lines',
              'name': 'e01-8-pre, WV02',
              'showlegend': False,
              'type': 'scattergl',
              'x': array([  1,   2,   3, ..., 398, 399, 400], dtype=int64),
              'xaxis': 'x2',
              'y': array([25.30857086, 28.39942531, 29.38547043, ..., 35.54164964, 35.54578508,
                          35.55044081]),
              'yaxis': 'y2'}],
    'layout': {'annotations': [{'font': {},
                                'showarrow': False,
                                'text': 'train',
                                'x': 0.245,
                                'xanchor': 'center',
                                'xref': 'paper',
                                'y': 1.0,
                                'yanchor': 'bottom',
                                'yref': 'paper'},
                               {'font': {},
                                'showarrow': False,
                                'text': 'val',
                                'x': 0.755,
                                'xanchor': 'center',
                                'xref': 'paper',
                                'y': 1.0,
                                'yanchor': 'bottom',
                                'yref': 'paper'}],
               'legend': {'title': {'text': 'Experiment variation'}, 'tracegroupgap': 0},
               'template': '...',
               'title': {'text': 'PSNR'},
               'xaxis': {'anchor': 'y', 'domain': [0.0, 0.49], 'title': {'text': 'epoch'}},
               'xaxis2': {'anchor': 'y2', 'domain': [0.51, 1.0], 'matches': 'x', 'title': {'text': 'epoch'}},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'range': [29, 44], 'title': {'text': 'value'}},
               'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'matches': 'y', 'showticklabels': False}}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])