In [22]:
import pathlib
import warnings

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from modules.plots import *

In [65]:
# Metric columns kept from every per-tile CSV (also the order metrics are plotted in).
METRICS = [
    'PSNR',
    'SSIM',
    'NIQE',
]

# Experiment sub-directories of LOGS_DIR; results are loaded in this order.
EXPERIMENTS = [
    'e01-6',
    'e01-4',
    'e01-3',
]

# Root folder holding '<experiment>/csv/*.csv' and the exported summary tables.
LOGS_DIR = 'logs'

# Palette from modules.plots (requested in 'hex' form); one entry per experiment.
PLOTLY_COLORS = get_plotly_standard_colors('hex')

In [66]:
def final_epoch_read_csv(csv_path, metrics_filter, id_column='ms_tile_path'):
    """Load a per-tile metrics CSV keeping only the tile id and the requested metrics.

    Parameters
    ----------
    csv_path : str, pathlib.Path or file-like
        Anything accepted by ``pandas.read_csv``.
    metrics_filter : str or iterable of str
        Metric column name(s) to keep, in output order. A single string is
        treated as one column name (not iterated character by character).
    id_column : str, optional
        Identifier column placed first in the result. Defaults to 'ms_tile_path'.

    Returns
    -------
    pandas.DataFrame
        Columns ``[id_column, *metrics_filter]``; every other CSV column is dropped.

    Raises
    ------
    KeyError
        If any requested column is missing from the CSV.
    """
    if isinstance(metrics_filter, str):
        metrics_filter = [metrics_filter]
    columns = [id_column, *metrics_filter]
    return pd.read_csv(csv_path).loc[:, columns]

In [67]:
# Load the final-epoch per-tile metrics of every experiment.
# Keys are '<experiment>-<csv stem>', e.g. 'e01-4-final_epoch-pre-val-WV02'.
dfs = {}
for experiment in EXPERIMENTS:
    csv_dir = pathlib.Path(LOGS_DIR) / experiment / 'csv'
    # glob() order depends on the filesystem; sort so the key order of `dfs`
    # (and therefore the row order of the statistics tables) is reproducible.
    csv_paths = sorted(csv_dir.glob('*.csv'))
    if not csv_paths:
        # Otherwise the experiment would silently vanish from every plot and table.
        warnings.warn(f'No CSV files found in {csv_dir}')
    for csv_path in csv_paths:
        name = f'{experiment}-{csv_path.stem}'
        print(name)
        dfs[name] = final_epoch_read_csv(csv_path, metrics_filter=METRICS)

e01-6-final_epoch-gan-val-WV02
e01-6-final_epoch-pre-val-WV02
e01-4-final_epoch-gan-val-GE01
e01-4-final_epoch-gan-val-WV02
e01-4-final_epoch-pre-val-GE01
e01-4-final_epoch-pre-val-WV02
e01-3-final_epoch-gan-val-GE01
e01-3-final_epoch-gan-val-WV02
e01-3-final_epoch-pre-val-GE01
e01-3-final_epoch-pre-val-WV02


In [68]:
# Spot-check one loaded frame: experiment e01-4, 'pre' results on the WV02 validation tiles.
# NOTE(review): ms_tile_path values appear as the text of bytes objects (b'...'),
# presumably written that way upstream; consider decoding them when the CSVs are produced.
dfs['e01-4-final_epoch-pre-val-WV02']

Unnamed: 0,ms_tile_path,PSNR,SSIM,NIQE
0,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,33.853783,0.885305,6.340402
1,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,33.642094,0.865780,7.057770
2,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,37.515450,0.960813,8.777512
3,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,36.422112,0.898833,7.180161
4,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,32.609421,0.839028,6.646343
...,...,...,...,...
4319,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,25.573612,0.619367,7.581877
4320,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,29.290121,0.832043,18.241493
4321,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,29.115988,0.822919,18.906858
4322,b'data\\toulon-laspezia-tiles\\e01\\val\\WV02_...,21.395348,0.352485,6.647622


In [69]:
def _experiment_from_df_name(df_name):
    """Return the '<experiment>' prefix of a '<experiment>-final_epoch-...' frame name."""
    experiment, sep, _ = df_name.partition('-final_epoch')
    # Names without the stem marker keep the old fixed-width (5-character) prefix.
    return experiment if sep else df_name[:5]


def metric_histogram(dfs, metric, colors=None):
    """Show per-tile histograms of one metric in a 2x2 grid, one color per experiment.

    Layout: row 1 holds frames whose name contains a 'pre' token, row 2 all
    others ('gan' in this notebook); column 1 holds frames whose name contains
    a 'WV02' token, column 2 all others (GE01). Each experiment gets a single
    color and a single legend entry shared across the four panels.

    Parameters
    ----------
    dfs : dict[str, pandas.DataFrame]
        Frames keyed '<experiment>-final_epoch-<stage>-<split>-<sensor>'.
    metric : str
        Column to histogram; also used as the x-axis label and figure title.
    colors : sequence of str, optional
        Palette indexed per experiment, cycled when there are more experiments
        than colors. Defaults to PLOTLY_COLORS.
    """
    if colors is None:
        colors = PLOTLY_COLORS
    fig = make_subplots(
        rows=2, cols=2,
        horizontal_spacing=0.04,
        vertical_spacing=0.04,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title=metric,
        subplot_titles=['WV02', 'GE01'],
        # Without row labels the pre/gan split between the two rows is invisible.
        row_titles=['pre', 'gan'],
    )
    experiment_color = {}
    for df_name, df in dfs.items():
        # Match whole '-'-separated tokens so experiment names cannot mis-route a frame.
        tokens = df_name.split('-')
        row = 1 if 'pre' in tokens else 2
        col = 1 if 'WV02' in tokens else 2
        experiment = _experiment_from_df_name(df_name)
        first_seen = experiment not in experiment_color
        if first_seen:
            # Cycle the palette instead of raising IndexError past its last color.
            experiment_color[experiment] = colors[len(experiment_color) % len(colors)]
        fig.add_trace(
            go.Histogram(
                x=df[metric],
                legendgroup=experiment,
                # Name every trace so hover labels show the experiment, not "trace N".
                name=experiment,
                showlegend=first_seen,
                marker_color=experiment_color[experiment],
            ),
            row=row, col=col,
        )
    fig.update_layout(title=metric)
    fig.show()
# One histogram figure per configured metric, in METRICS order.
for metric in METRICS:
    metric_histogram(dfs, metric=metric)

In [115]:
def compute_statistics(dfs, statistic='mean', decimals=2):
    """Reduce every frame in ``dfs`` to one row of per-column statistics.

    Parameters
    ----------
    dfs : dict[str, pandas.DataFrame]
        Frames to summarize; each key becomes a row label with 'final_epoch-' removed.
    statistic : {'mean', 'median', 'std', 'count', 'sem'}
        DataFrame reduction to apply. 'count' counts non-null values in every
        column; the other statistics use numeric columns only. 'std' and 'sem'
        use pandas' default ddof=1.
    decimals : int or dict, optional
        Passed to ``DataFrame.round``; a dict maps column names to decimals.

    Returns
    -------
    pandas.DataFrame
        One row per input frame, one column per summarized input column.

    Raises
    ------
    ValueError
        If ``statistic`` is not one of the supported names.
    """
    numeric_statistics = ('mean', 'median', 'std', 'sem')
    if statistic == 'count':
        kwargs = {}
    elif statistic in numeric_statistics:
        # Skip text columns such as ms_tile_path: pandas >= 2.0 raises TypeError on
        # them, while pandas 1.x silently dropped them (the output this notebook relies on).
        kwargs = {'numeric_only': True}
    else:
        raise ValueError(
            f"Unknown statistic {statistic!r}; expected 'count' or one of {numeric_statistics}"
        )

    stats = {
        df_name.replace('final_epoch-', ''): getattr(df, statistic)(**kwargs)
        for df_name, df in dfs.items()
    }
    return pd.DataFrame.from_dict(stats).transpose().round(decimals)

# Per-metric rounding for the exported tables; SSIM keeps one extra decimal.
METRIC_DECIMALS = {'PSNR': 2, 'SSIM': 3, 'NIQE': 2}

df_mean = compute_statistics(dfs, statistic='mean', decimals=METRIC_DECIMALS)
df_count = compute_statistics(dfs, statistic='count')

# sem = standard error of the mean (with n-1)
df_sem = compute_statistics(dfs, statistic='sem', decimals=METRIC_DECIMALS)

In [117]:
# Export the summary tables as logs/e01-<suffix>.csv.
for suffix, table in (('means', df_mean), ('sems', df_sem), ('counts', df_count)):
    table.to_csv(f'{LOGS_DIR}/e01-{suffix}.csv')

In [116]:
# Non-null values per column for each split. At the last run every column had the
# same count within a split (no missing metrics): 3789 per GE01 split, 4324 per WV02 split.
df_count

Unnamed: 0,ms_tile_path,PSNR,SSIM,NIQE
e01-6-gan-val-WV02,4324,4324,4324,4324
e01-6-pre-val-WV02,4324,4324,4324,4324
e01-4-gan-val-GE01,3789,3789,3789,3789
e01-4-gan-val-WV02,4324,4324,4324,4324
e01-4-pre-val-GE01,3789,3789,3789,3789
e01-4-pre-val-WV02,4324,4324,4324,4324
e01-3-gan-val-GE01,3789,3789,3789,3789
e01-3-gan-val-WV02,4324,4324,4324,4324
e01-3-pre-val-GE01,3789,3789,3789,3789
e01-3-pre-val-WV02,4324,4324,4324,4324
