In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pathlib

In [2]:
# Root directory that is searched recursively for exported metric CSV files.
CSV_DIR = 'logs/csv'
# Substring used to select CSV files by name.
METRIC_NAME = 'PSNR'
# Column name given to the metric values in the merged DataFrame.
METRIC = 'PSNR'
# Selects the human-readable label below: 'train', 'val' or 'both'.
TRAIN_VAL = 'train'
# Kept separate from TRAIN_VAL_LONG so the name is not rebound from a dict to
# a string, and so that every documented TRAIN_VAL option has a label.
TRAIN_VAL_LABELS = {'train': 'Training set',
                    'val': 'Validation set',
                    'both': 'Training and validation sets'}
TRAIN_VAL_LONG = TRAIN_VAL_LABELS[TRAIN_VAL]
# Exponential smoothing of the curves; the EWM alpha used is 1 - SMOOTH_PAR.
SMOOTH = True
SMOOTH_PAR = 0.9

In [56]:
def list_exp_csv(csv_dir, exp_name, pre_gan, train_val, metric):
    """Find the metric CSV files exported for one experiment.

    Searches ``csv_dir`` recursively for files matching the glob pattern
    ``**/<exp_name>/*<pre_gan>*<metric>*.csv``.

    Args:
        csv_dir: Root directory to search; a string or any path-like object.
        exp_name: Name of the directory that directly contains the CSV files.
        pre_gan: Substring that must appear in the file name.
        train_val: Currently NOT used for filtering (files of every split are
            returned). The parameter is kept so existing callers keep working.
        metric: Substring that must appear in the file name after ``pre_gan``.

    Returns:
        A list of ``pathlib.Path`` objects, sorted so the result does not
        depend on the order in which the filesystem lists the files.
    """
    csv_dir = pathlib.Path(csv_dir)
    pattern = '**/' + exp_name + '/*' + pre_gan + '*' + metric + '*.csv'
    # glob() yields entries in arbitrary order; sort for reproducible output.
    return sorted(csv_dir.glob(pattern))

def csv_list_to_dfs(csv_list):
    """Load each CSV file into a DataFrame.

    Args:
        csv_list: Iterable of path objects exposing ``.stem``.

    Returns:
        dict mapping each file's stem (name without extension) to the
        DataFrame read from it. Files sharing a stem overwrite one another,
        the last one winning.
    """
    return {path.stem: pd.read_csv(path) for path in csv_list}

def _exp_variation_from_name(name):
    """Extract the experiment-variation part of a CSV file stem.

    Returns the text that starts right after the underscore of the first
    ``'_e'`` and ends before the next ``'_2'``, e.g.
    ``'run-tb_e01-3-pre_20210124-153415_train-...'`` -> ``'e01-3-pre'``.
    If ``'_e'`` is absent the slice starts at the beginning of the name; if
    ``'_2'`` is absent it runs to the end of the name.
    """
    marker = name.find('_e')
    start = marker + 1 if marker != -1 else 0
    end = name.find('_2', start)
    if end == -1:
        # Without this guard the slice would silently drop the last character.
        end = len(name)
    return name[start:end]


def merge_dfs_by_epoch(dfs, metric, wide_long, smooth=True, smooth_par=0.9):
    """Combine per-file metric DataFrames into one DataFrame indexed by epoch.

    Args:
        dfs: dict mapping a name (the CSV file stem) to a DataFrame that has a
            ``'Value'`` column. The ``epoch`` column is built from the index of
            the first DataFrame as ``index + 1``.
        metric: Name of the value column in the long-format result.
        wide_long: ``'long'`` returns one row per (file, epoch) with extra
            descriptive columns; any other value returns the wide frame
            (``epoch`` plus one raw-value column per file).
        smooth: When true, the long-format result also gets a
            ``'<metric>-smooth'`` column.
        smooth_par: Smoothing strength; the exponentially weighted mean uses
            ``alpha = 1 - smooth_par``.

    Returns:
        pandas.DataFrame. In long format the columns are ``epoch``,
        ``exp_variation_id`` (the file stem), ``<metric>``, optionally
        ``<metric>-smooth``, ``train_val``, ``sensor``, ``exp_variation``,
        ``legend`` and ``dash``.
    """
    wide_df = pd.DataFrame()
    smooth_wide_df = pd.DataFrame()
    variation_by_name = {}
    for name, df in dfs.items():
        if 'epoch' not in wide_df.columns:
            wide_df['epoch'] = df.index + 1
            smooth_wide_df['epoch'] = df.index + 1
        variation_by_name[name] = _exp_variation_from_name(name)
        wide_df[name] = df['Value']
        if smooth:
            smooth_wide_df[name] = df['Value'].ewm(alpha=(1 - smooth_par)).mean()

    if wide_long != 'long':
        return wide_df

    long_df = pd.melt(wide_df, id_vars=['epoch'], value_name=metric,
                      var_name='exp_variation_id')
    if smooth:
        # Both frames are melted from identically shaped wide frames, so their
        # rows line up one-to-one.
        smooth_long_df = pd.melt(smooth_wide_df, id_vars=['epoch'],
                                 value_name=metric, var_name='exp_variation_id')
        long_df[metric + '-smooth'] = smooth_long_df[metric]

    ids = long_df['exp_variation_id']

    # Split tag: rows matching neither substring stay NaN; 'val' is applied
    # last and therefore wins when a name contains both substrings.
    train_val = pd.Series(float('nan'), index=ids.index, dtype=object)
    train_val[ids.str.contains('train', regex=False)] = 'train'
    train_val[ids.str.contains('val', regex=False)] = 'val'
    long_df['train_val'] = train_val

    # Sensor tag: 'WV02' is the default for names that mention no sensor.
    # Built as a fresh column instead of a chained in-place fillna, which does
    # not update the frame when pandas Copy-on-Write is active.
    sensor = pd.Series('WV02', index=ids.index, dtype=object)
    sensor[ids.str.contains('GE01', regex=False)] = 'GE01'
    long_df['sensor'] = sensor

    # Exact per-file lookup: substring matching would let e.g. 'e01-1' also
    # claim the rows belonging to 'e01-10'.
    long_df['exp_variation'] = ids.map(variation_by_name)
    long_df['legend'] = long_df['exp_variation'] + '-' + long_df['sensor']
    long_df['dash'] = 'dot'
    return long_df

# Locate the CSV exports of experiment 'e01' whose file names contain 'pre'
# and the configured metric name.
csv_list = list_exp_csv(
    CSV_DIR,
    'e01',
    pre_gan='pre',
    train_val=TRAIN_VAL,
    metric=METRIC_NAME,
)

# Read every file, then merge them into one long-format frame
# (one row per file and epoch).
exp_dfs = csv_list_to_dfs(csv_list)
metric_df = merge_dfs_by_epoch(
    exp_dfs,
    metric=METRIC,
    wide_long='long',
    smooth=SMOOTH,
    smooth_par=SMOOTH_PAR,
)
metric_df

Unnamed: 0,epoch,exp_variation_id,PSNR,PSNR-smooth,train_val,sensor,exp_variation,legend,dash
0,1,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,16.337782,16.337782,train,WV02,e01-3-pre,e01-3-pre-WV02,dot
1,2,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,25.686089,21.257943,train,WV02,e01-3-pre,e01-3-pre-WV02,dot
2,3,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,28.110910,23.786713,train,WV02,e01-3-pre,e01-3-pre-WV02,dot
3,4,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,29.242495,25.373158,train,WV02,e01-3-pre,e01-3-pre-WV02,dot
4,5,run-tb_e01-3-pre_20210124-153415_train-tag-epo...,30.221743,26.557154,train,WV02,e01-3-pre,e01-3-pre-WV02,dot
...,...,...,...,...,...,...,...,...,...
3995,396,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,36.483772,35.597921,val,WV02,e01-8-pre,e01-8-pre-WV02,dot
3996,397,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,35.439014,35.582030,val,WV02,e01-8-pre,e01-8-pre-WV02,dot
3997,398,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,35.178223,35.541650,val,WV02,e01-8-pre,e01-8-pre-WV02,dot
3998,399,run-tb_e01-8-pre_20210116-194500_val-WV02-tag-...,35.583004,35.545785,val,WV02,e01-8-pre,e01-8-pre-WV02,dot


In [66]:
# Plot the smoothed curve when it exists. merge_dfs_by_epoch only adds the
# '<METRIC>-smooth' column when smoothing is enabled, so fall back to the raw
# metric column instead of failing on a missing column.
smooth_column = METRIC + '-smooth'
y_column = smooth_column if smooth_column in metric_df.columns else METRIC

# One panel per train/val split, one colour per experiment variation and one
# dash style per sensor.
fig = px.line(metric_df,
              x='epoch',
              y=[y_column],
              color='exp_variation',
              range_y=(25, 44),
              title=METRIC,
              facet_col='train_val',
              line_dash='sensor',
              line_dash_sequence=['solid', 'dash']
             )
fig.update_layout(legend_title_text='Experiment variation')
fig.show()