In [1]:
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pathlib

from modules.plots import *

In [2]:
# Root folder holding one sub-directory per experiment (each with a 'csv' folder).
LOGS_DIR = 'logs'
# Experiment sub-directories of LOGS_DIR to load.
EXPERIMENTS = ['e01-8', 'e01-6', 'e01-4', 'e01-3']
# Metric columns to read from each per-tile CSV; metrics missing from a given CSV are skipped.
METRICS = ['PSNR', 'SSIM', 'NIQE', 'Ma', 'PI']

# Colour palette from the project's plotting helpers (modules.plots); indexed by
# integer in the plotting cells. Presumably a list of hex strings -- TODO confirm.
PLOTLY_COLORS = get_plotly_standard_colors('hex')

In [3]:
def final_epoch_read_csv(csv_path, metrics_filter):
    """Read a per-tile metrics CSV, keeping only the tile path and the requested metrics.

    Parameters
    ----------
    csv_path : str, pathlib.Path or file-like
        CSV location. Strings are converted to ``pathlib.Path``; anything else is
        passed to ``pd.read_csv`` unchanged.
    metrics_filter : iterable of str
        Metric column names to keep. Names that are not columns of this CSV are
        skipped, so one list can be used for CSVs that carry different metrics.
        Duplicates are ignored.

    Returns
    -------
    pandas.DataFrame
        ``ms_tile_path`` followed by the available metrics, in ``metrics_filter`` order.
    """
    if isinstance(csv_path, str):
        csv_path = pathlib.Path(csv_path)
    df = pd.read_csv(csv_path)

    # Preserve the caller's order: a set intersection would make the column order
    # depend on string hashing, which changes between interpreter sessions.
    # dict.fromkeys de-duplicates while keeping first-seen order.
    available = set(df.columns)
    metric_columns = [
        metric for metric in dict.fromkeys(metrics_filter)
        if metric in available and metric != 'ms_tile_path'
    ]

    return df.loc[:, ['ms_tile_path'] + metric_columns]

In [4]:
def path_to_columns(row):
    r"""Add area, sensor, split and image/tile identifiers parsed from ``row['ms_tile_path']``.

    Intended for ``DataFrame.apply(path_to_columns, axis=1)``. Judging by the loaded
    output, ``ms_tile_path`` is the text form of a bytes path such as
    ``b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_..._0\\ms\\00000.tif'``. It uses
    doubled backslashes as separators and ends with a quote character; TODO confirm
    this holds for every CSV.

    Adds:
      * ``area``      -- 'Toulon' or 'La_Spezia' (only if one of them is found);
      * ``sensor``    -- 'WV02' or 'GE01' (the later one wins if both occur);
      * ``val_test``  -- 'val' or 'test' (the later one wins if both occur);
      * ``image_UID`` -- the directory from the sensor name up to ``\\ms\\``;
      * ``tile_UID``  -- the file name after ``\\ms\\`` without the trailing quote.

    Raises
    ------
    ValueError
        If no known sensor or no ``\\ms\\`` directory is present in the path.
        Previously this surfaced as ``KeyError: 'sensor'`` or as silently garbled IDs.
    """
    path = row['ms_tile_path']
    sep = '\\\\'                     # two literal backslash characters
    ms_marker = sep + 'ms' + sep     # '\\ms\\' -> 6 characters

    for town in ['_Toulon', '_La_Spezia_']:
        if town in path:
            row['area'] = town.strip('_')

    sensor = None
    for candidate in ['WV02', 'GE01']:
        if candidate in path:
            sensor = candidate
    if sensor is None:
        raise ValueError(f'No known sensor (WV02/GE01) found in ms_tile_path: {path!r}')
    row['sensor'] = sensor

    for val_test in ['val', 'test']:
        if val_test in path:
            row['val_test'] = val_test

    ms_start = path.find(ms_marker)
    if ms_start == -1:
        raise ValueError(f"No '\\ms\\' directory found in ms_tile_path: {path!r}")
    row['image_UID'] = path[path.find(sep + sensor) + len(sep):ms_start]
    # [:-1] drops the closing quote of the bytes repr.
    row['tile_UID'] = path[ms_start + len(ms_marker):-1]
    return row

In [5]:
# Load every per-tile metrics CSV of every experiment into `dfs`, keyed by
# "<experiment>-<csv stem>" (e.g. "e01-4-final_epoch-gan-val-GE01").
# Files are sorted because Path.glob order is arbitrary. The key order drives the
# legend colours and the row order of the statistics tables below.
dfs = {}
for experiment in EXPERIMENTS:
    csv_dir = pathlib.Path(LOGS_DIR) / experiment / 'csv'
    csv_paths = sorted(csv_dir.glob('*.csv'))
    if not csv_paths:
        print(f'{experiment}: no CSV files found in {csv_dir}')
    for csv_path in csv_paths:
        name = f'{experiment}-{csv_path.stem}'
        print(name)
        dfs[name] = final_epoch_read_csv(csv_path, metrics_filter=METRICS).apply(path_to_columns, axis=1)


e01-6-final_epoch-gan-val-WV02
e01-6-final_epoch-pre-val-WV02
e01-4-final_epoch-gan-val-GE01
e01-4-final_epoch-gan-val-WV02
e01-4-final_epoch-pre-val-GE01
e01-4-final_epoch-pre-val-WV02
e01-3-final_epoch-gan-val-GE01
e01-3-final_epoch-gan-val-WV02
e01-3-final_epoch-pre-val-GE01
e01-3-final_epoch-pre-val-WV02


In [6]:
def metric_histogram(dfs, metric):
    """Show histograms of one metric for every loaded run in a 2x2 grid.

    Grid placement is decided from each key of ``dfs``:
      * row 1: names containing 'pre' (y-axis 'Pretrained (L1)');
        row 2: everything else ('GAN-trained');
      * column 1: names containing 'WV02'; column 2: everything else (GE01).
    Traces are grouped and coloured per experiment, taken as the first five
    characters of the name (e.g. 'e01-4'). Only the first trace of each group gets
    a legend entry. Frames without a ``metric`` column are skipped.

    Parameters
    ----------
    dfs : dict[str, pandas.DataFrame]
        Frames keyed like 'e01-4-final_epoch-gan-val-GE01'.
    metric : str
        Column to histogram; also used as the figure title and shared x-axis title.
    """
    fig = make_subplots(
        rows=2, cols=2,
        horizontal_spacing=0.04,
        vertical_spacing=0.04,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title=metric,
        subplot_titles=['WV02', 'GE01'],
    )
    legend_colors = {}  # experiment id -> marker colour
    for df_name, df in dfs.items():
        if metric not in df.columns:
            continue
        row = 1 if 'pre' in df_name else 2
        col = 1 if 'WV02' in df_name else 2
        # NOTE(review): assumes 5-character experiment ids such as 'e01-4'.
        legendgroup = df_name[:5]
        first_in_group = legendgroup not in legend_colors
        if first_in_group:
            # Cycle through the palette so more experiments than colours cannot raise IndexError.
            legend_colors[legendgroup] = PLOTLY_COLORS[len(legend_colors) % len(PLOTLY_COLORS)]
        fig.add_trace(
            go.Histogram(
                x=df[metric],
                legendgroup=legendgroup,
                name=legendgroup,
                showlegend=first_in_group,
                marker_color=legend_colors[legendgroup],
            ),
            row=row, col=col,
        )
    fig.update_layout(title=metric)
    fig.update_yaxes(title='Pretrained (L1)', row=1, col=1)
    fig.update_yaxes(title='GAN-trained', row=2, col=1)
    fig.show()
# One figure per metric; SSIM is intentionally left out of this run.
for metric_name in ['PSNR', 'NIQE', 'Ma', 'PI']:
    metric_histogram(dfs, metric=metric_name)

In [7]:
# Inspect one loaded frame to check the columns parsed from the tile paths.
example_name = 'e01-4-final_epoch-gan-val-GE01'
dfs[example_name]

Unnamed: 0,ms_tile_path,Ma,NIQE,PSNR,PI,SSIM,area,sensor,val_test,image_UID,tile_UID
0,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.765611,3.386027,27.399662,4.310208,0.855645,La_Spezia,GE01,val,GE01_La_Spezia_2012_02_23_011651192010_0,00000.tif
1,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,2.911900,9.239561,42.147961,8.163831,0.995304,La_Spezia,GE01,val,GE01_La_Spezia_2012_02_23_011651192010_0,00001.tif
2,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.402814,4.528114,29.690521,5.062650,0.681053,La_Spezia,GE01,val,GE01_La_Spezia_2012_02_23_011651192010_0,00003.tif
3,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.131922,3.796207,33.996761,4.832143,0.888635,La_Spezia,GE01,val,GE01_La_Spezia_2012_02_23_011651192010_0,00004.tif
4,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.418512,4.643028,36.718353,5.112258,0.920875,La_Spezia,GE01,val,GE01_La_Spezia_2012_02_23_011651192010_0,00005.tif
...,...,...,...,...,...,...,...,...,...,...,...
3784,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.721834,4.867767,23.924448,5.072967,0.876824,Toulon,GE01,val,GE01_Toulon 2019_10_07_011651194010_0,00510.tif
3785,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.441097,5.039841,23.078400,5.299372,0.856587,Toulon,GE01,val,GE01_Toulon 2019_10_07_011651194010_0,00511.tif
3786,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.838021,4.038345,16.675341,4.600162,0.463720,Toulon,GE01,val,GE01_Toulon 2019_10_07_011651194010_0,00512.tif
3787,b'data\\toulon-laspezia-tiles\\e01\\val\\GE01_...,4.480389,3.765590,17.889948,4.642601,0.311867,Toulon,GE01,val,GE01_Toulon 2019_10_07_011651194010_0,00513.tif


In [8]:
def histograms_color_facet(df, color='area', facet='image_UID', x='PI'):
    """Show a histogram of one metric column of ``df``, with bars coloured by ``color``.

    Parameters
    ----------
    df : pandas.DataFrame
        A loaded frame: metric columns plus the columns added by ``path_to_columns``.
    color : str
        Column used to colour the bars. Defaults to ``'area'``. The former default
        ``'town'`` is not a column of these frames, so calls relying on it raised.
    facet : str
        Currently unused: faceting by this column was disabled. It is kept so
        existing calls keep working.
    x : str
        Metric column to histogram (default ``'PI'``, previously hard-coded).
    """
    fig = px.histogram(df, x=x, color=color)
    fig.show()
    
# PI distribution per area for the e01-4 GAN validation runs, one figure per sensor.
for sensor_name in ['WV02', 'GE01']:
    histograms_color_facet(
        dfs[f'e01-4-final_epoch-gan-val-{sensor_name}'],
        color='area',
        facet='image_UID',
    )

In [16]:
def compute_statistics(dfs, statistic='mean', decimals=2):
    """Summarise every frame in ``dfs`` with one statistic, one output row per frame.

    Parameters
    ----------
    dfs : dict[str, pandas.DataFrame]
        Frames to summarise; keys become row labels with 'final_epoch-' removed.
    statistic : {'mean', 'median', 'std', 'count', 'sem'}
        pandas reduction to apply. 'sem' is the standard error of the mean
        (pandas default ddof=1). 'count' counts non-missing values in all columns;
        the other statistics use numeric columns only.
    decimals : int or dict
        Passed to ``DataFrame.round``; a dict maps column names to decimals.

    Returns
    -------
    pandas.DataFrame
        One row per frame, one column per summarised column. Columns missing from
        a frame are NaN in its row.

    Raises
    ------
    ValueError
        If ``statistic`` is not supported (previously an empty table was returned).
    """
    supported = ('mean', 'median', 'std', 'count', 'sem')
    if statistic not in supported:
        raise ValueError(f'Unsupported statistic {statistic!r}; expected one of {supported}')

    stats = {}
    for df_name, df in dfs.items():
        row_name = df_name.replace('final_epoch-', '')
        if statistic == 'count':
            stats[row_name] = df.count()
        else:
            # The frames also hold string columns (tile path, area, sensor, ...).
            # pandas >= 2 raises TypeError on them unless numeric_only=True;
            # older versions dropped them with a warning.
            stats[row_name] = getattr(df, statistic)(numeric_only=True)

    return pd.DataFrame.from_dict(stats).transpose().round(decimals)

# Per-metric rounding shared by the mean and SEM tables (SSIM keeps 3 decimals, the rest 2).
STAT_DECIMALS = {'PSNR': 2, 'SSIM': 3, 'NIQE': 2, 'Ma': 2, 'PI': 2}

df_mean = compute_statistics(dfs, statistic='mean', decimals=STAT_DECIMALS)
df_count = compute_statistics(dfs, statistic='count')
# sem = standard error of the mean (with n-1)
df_sem = compute_statistics(dfs, statistic='sem', decimals=STAT_DECIMALS)

In [17]:
# Write the summary tables into LOGS_DIR as e01-<kind>.csv.
for table, suffix in [(df_mean, 'means'), (df_sem, 'sems'), (df_count, 'counts')]:
    table.to_csv(f'{LOGS_DIR}/e01-{suffix}.csv')

In [20]:
# Mean of each metric per loaded run (rows are the dfs keys without 'final_epoch-').
# NaN means the metric column was absent from that run's CSV: final_epoch_read_csv
# keeps only the metrics present in each file.
df_mean

Unnamed: 0,Ma,NIQE,PI,PSNR,SSIM
e01-6-gan-val-WV02,,4.93,,34.45,0.797
e01-6-pre-val-WV02,,8.53,,35.43,0.818
e01-4-gan-val-GE01,4.08,5.31,5.62,31.95,0.783
e01-4-gan-val-WV02,4.46,4.89,5.22,33.51,0.806
e01-4-pre-val-GE01,,8.23,,33.18,0.812
e01-4-pre-val-WV02,,8.33,,34.78,0.834
e01-3-gan-val-GE01,,5.02,,31.43,0.757
e01-3-gan-val-WV02,,4.69,,32.56,0.765
e01-3-pre-val-GE01,,8.33,,31.44,0.76
e01-3-pre-val-WV02,,8.4,,33.86,0.799
