# Result data analyses from complex training
### Data fetched either locally or from Google Drive via OAuth

In [None]:
# !rm -f *.png

## Fill in the form to run the notebook

Forms provide an easy way to parameterize code. From a code cell, select Insert → Add form field. When you change the value in a form, the corresponding value in the code will change.

In [None]:
#@title Fill in this form to run the notebook:
import os
import warnings
warnings.filterwarnings("ignore", message="Numerical issues were encountered ")

#@markdown ---
#@markdown ##### Enter Root path:
# "." and "/content" read result files directly from that folder; any other root is
# expected to hold them under <root>/<date>/<time>/train/.
root_path = "." # @param [".", "/content", "/content/drive/My Drive/Siren Deep Learning Analyses/results"]

#@markdown ---
#@markdown ##### Enter trial number:
train_no =  0  #@param {type:"integer", min:0, max:23, step:1}

#@markdown ---
#@markdown ##### Toggle checkbox to download resulting pictures:
download_pictures_checkbox = False #@param {type:"boolean"}

#@markdown ##### Toggle checkbox to fetch missing result files from Google Drive:
fetch_data_from_gdrive_checkbox = True #@param {type:"boolean"}

#@markdown ---
#@markdown ##### Enter dirname and image name:
dir_image = "/content/drive/My Drive/Siren Deep Learning Analyses/testsets/BSD68" # @param ["/content/", "/content/drive/My Drive/Siren Deep Learning Analyses/testsets/BSD68"]
image_name = "test068.png" #@param {type:"string"}

# Output folder for the generated figures, relative to the working directory.
basedir_path_out_images = f"mixed_out_train_{train_no}"

# Google Drive file id: the shareable link of a file contains this id; paste it here.
file_id = ''

In [None]:
# Each field is a ';'-separated list with one entry per train, truncated to
# its first three entries.
dates_input, train_timestamps, trains_no, file_ids = (
    raw_field.split(";")[:3] for raw_field in (';;', ";;", ";;", ";;")
)

## Authentication phase

In [None]:
from apiclient import discovery
from httplib2 import Http
import oauth2client
from oauth2client import file, client, tools
# Command-line-style flags handed to tools.run_flow(). A bare function object is
# used as a plain attribute container to fake the argparse namespace it expects.
# NOTE(review): noauth_local_webserver holds the non-empty (hence truthy) string
# 'store_true' — probably meant as True; confirm.
lmao = dict(
    auth_host_name='localhost',
    noauth_local_webserver='store_true',
    auth_host_port=[8080, 8090],
    logging_level='ERROR',
)
obj = lambda: None
for flag_name, flag_value in lmao.items():
    setattr(obj, flag_name, flag_value)

# Authorization boilerplate: reuse the credentials cached in token.json if present.
SCOPES = 'https://www.googleapis.com/auth/drive.readonly'
store = file.Storage('token.json')
creds = store.get()
# Missing or invalid credentials: run the consent flow, which gives the user a link
# to grant this app read-only access to Drive; the result is stored in token.json.
needs_consent = (not creds) or creds.invalid
if needs_consent:
    flow = client.flow_from_clientsecrets('client_id.json', SCOPES)
    creds = tools.run_flow(flow, store, obj)

## Setup

### Installations

In [None]:
# !pip install -q gwpy

In [None]:
# Clean /content from trash or old .png images
# !rm -f /content/*.png

### Imports

In [None]:
from datetime import datetime
# from google.colab import files

from pathlib import Path
from collections import namedtuple

# import psycopg2 as ps
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import seaborn as sns
# sns.set_theme(style="whitegrid")
import ipywidgets as widgets
# back end of ipywidgets
from IPython.display import display

import io
from googleapiclient.http import MediaIoBaseDownload
import zipfile

import ast
import collections
import itertools
import functools
import glob
import operator
import os
import re
import numpy as np
import pandas as pd

from PIL import Image

# skimage
import skimage
import skimage.metrics as skmetrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

### Setup output images full path

In [None]:
# Create the figures folder. exist_ok makes re-running this cell a no-op, while real
# failures (e.g. permission errors) are reported instead of silently swallowed.
os.makedirs(basedir_path_out_images, exist_ok=True)

In [None]:
image_kind_str = "df_scatter;scatter;bar;reg;point;box;violin;complex"
# One figure kind per non-empty entry, sorted alphabetically and suffixed with "plot"
# (e.g. "bar" -> "barplot"); these become the fields of ImagesConf.
images_kind = [f"{kind}plot" for kind in sorted(image_kind_str.split(";")) if kind]

ImagesConf = namedtuple('ImagesConf', images_kind)

In [None]:
half_name = f"mse_psnr_et_al_vs_no_params_train_no_{train_no}.png"


def full_path_out_images(item, root_path = basedir_path_out_images, half_name = half_name):
    """Return the output path ``<root_path>/<item>_<half_name>`` for one figure kind.

    Both defaults are bound when this cell runs (current output folder / train number).
    """
    file_name = f"{item}_{half_name}"
    return os.path.join(root_path, file_name)


# Output path of every figure kind, in the same order as `images_kind`.
image_names = [full_path_out_images(kind) for kind in images_kind]

In [None]:
# Bundle the output paths so each figure can be addressed by kind, e.g. images_conf.barplot.
images_conf = ImagesConf(*image_names)

### Functions

In [None]:
def compute_graph_image_psnr_CR(data_tuples, x_axes, y_axes, subject, colors = None, figsize = (20, 10)):
    """Scatter every (x, y) attribute pair of ``data_tuples`` on its own subplot.

    A ``len(x_axes) x len(y_axes)`` grid is created. Panel ``k`` shows the k-th pair of
    ``itertools.product(x_axes, y_axes)``, so rows follow ``x_axes`` and columns follow
    ``y_axes``.

    Args:
        data_tuples: records (e.g. namedtuples) exposing the named attributes.
        x_axes: attribute names plotted on the x axis.
        y_axes: attribute names plotted on the y axis.
        subject: label prefix used in each legend entry.
        colors: sequence of colors, indexed cyclically by panel; defaults to
            ``sns.color_palette()``.
        figsize: size of the created figure.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``.
    """
    if colors is None:
        colors = sns.color_palette()

    pairs_axes = list(itertools.product(x_axes, y_axes))

    fig, axes = plt.subplots(len(x_axes), len(y_axes), figsize=figsize)
    # np.ravel flattens a single Axes, a 1-D or a 2-D array of Axes alike.
    axes_list = list(np.ravel(axes))

    for ii, (ax, (x_axis, y_axis)) in enumerate(zip(axes_list, pairs_axes)):
        x = np.array([getattr(item, x_axis) for item in data_tuples])
        y = np.array([getattr(item, y_axis) for item in data_tuples])
        # Cycle through the colors so grids larger than the palette still work.
        ax.scatter(x, y, marker = 'x', color = colors[ii % len(colors)], label = f'{subject} - {y_axis}')
        ax.set_ylabel(f'{y_axis}')
        ax.set_xlabel(f'{x_axis}')
        ax.legend()
        ax.set_title(f'{y_axis.upper()} vs. {x_axis.upper()}')
    return fig, axes

In [None]:
def graphics_bars_pointplot(dataframe, y_axes, x_axis, grid_shape, palette="Blues_d", axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Draw a summary grid: one row per plot kind, one column per ``y_axes`` column.

    Rows, top to bottom: scatter, bar (mean+std), point (mean+std), polynomial
    regression, box and violin plots of each ``y_axes`` column against ``x_axis``.
    The x axes of the bar, point, regression and box rows are hidden.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one grid column each.
        x_axis: column name used for grouping / the x coordinate.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; must provide at least
            ``6 * len(y_axes)`` Axes.
        palette: unused; each row uses its plotter's own default palette.
        axes: unused; a new figure is always created.
        figsize: size of the created figure.
        show_fig: call ``plt.show()`` before returning.
        title: figure suptitle.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``.

    Raises:
        ValueError: if the grid has fewer than ``6 * len(y_axes)`` Axes.
    """
    fig, axes = plt.subplots(*grid_shape, figsize=figsize)
    fig.suptitle(f'{title}', fontsize=15)

    # (plotter, hide_x_axis) for each grid row, in drawing order.
    rows = (
        (graphics_scatterplot, False),
        (graphics_bars_mean_std, True),
        (graphics_pointplot_mean_std, True),
        (graphics_regplot_mean_std, True),
        (graphics_boxplot, True),
        (graphics_violinplot, False),
    )

    axes_list = list(np.ravel(axes))
    n_cols = len(y_axes)
    if len(axes_list) < n_cols * len(rows):
        raise ValueError(
            f"grid_shape {grid_shape} gives {len(axes_list)} axes; "
            f"{n_cols * len(rows)} are needed ({len(rows)} rows x {n_cols} metrics)")

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for pos, (plotter, hide_x_axis) in enumerate(rows):
            row_axes = axes_list[n_cols * pos:n_cols * (pos + 1)]
            plotter(
                dataframe = dataframe,
                y_axes = y_axes,
                axes = row_axes,
                x_axis = x_axis)
            if hide_x_axis:
                for ax in row_axes:
                    ax.get_xaxis().set_visible(False)

    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_scatterplot(dataframe, y_axes, x_axis, grid_shape = None, palette="Blues_d", axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Scatter each ``y_axes`` column against ``x_axis`` with a log-scaled x axis.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: column name used for the x coordinate.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: unused (kept for a signature shared with the other helpers);
            colors come from ``sns.color_palette()``.
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    colors = sns.color_palette()
    # np.ravel flattens a single Axes, a list or a 2-D array of Axes alike.
    for ii, (ax, y_axis) in enumerate(zip(np.ravel(axes), y_axes)):
        # Cycle through the palette so more panels than colors still work.
        sns.scatterplot(x=f"{x_axis}", y=f"{y_axis}", data=dataframe, ax = ax, marker = 'x', color = colors[ii % len(colors)])
        ax.set_title(f'{y_axis.upper()}', fontsize=10)
        ax.set_xscale('log')

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_violinplot(dataframe, y_axes, x_axis, grid_shape = None, palette="Blues_d", axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Violin plot of each ``y_axes`` column, grouped by the distinct ``x_axis`` values.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: grouping column; its unique values, cast to int, label the ticks.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: unused; the violins are always drawn with the "Set3" palette.
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    # Sorted unique group values as int tick labels (assumes seaborn's default sorted
    # order for numeric categories). Builtin int: np.int was removed in NumPy 1.24.
    data_xtick_arr = np.unique(dataframe[f"{x_axis}"].values).astype(int)

    for ax, y_axis in zip(np.ravel(axes), y_axes):
        # NOTE(review): `bw` is deprecated in newer seaborn releases (bw_method/bw_adjust).
        sns.violinplot(x=f"{x_axis}", y=f"{y_axis}", data=dataframe, ax = ax, palette="Set3", bw=.2, cut=1, linewidth=1)
        ax.set_title(f'{y_axis.upper()}', fontsize=10)
        ax.set_xticklabels(data_xtick_arr, rotation=45)

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_boxplot(dataframe, y_axes, x_axis, grid_shape = None, palette="Blues_d", axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Box plot of each ``y_axes`` column, grouped by the distinct ``x_axis`` values.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: grouping column; its unique values, cast to int, label the ticks.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: seaborn palette for the boxes.
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    # Sorted unique group values as int tick labels (assumes seaborn's default sorted
    # order for numeric categories). Builtin int: np.int was removed in NumPy 1.24.
    data_xtick_arr = np.unique(dataframe[f"{x_axis}"].values).astype(int)

    for ax, y_axis in zip(np.ravel(axes), y_axes):
        sns.boxplot(x=f"{x_axis}", y=f"{y_axis}",
            data=dataframe,
            palette=palette, ax = ax)
        ax.set_title(f'{y_axis.upper()}', fontsize=10)
        ax.set_xticklabels(data_xtick_arr, rotation=45)

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_bars_mean_std(dataframe, y_axes, x_axis, grid_shape = None, palette="Blues_d", axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Bar plot of the per-group mean of each ``y_axes`` column with ±1 std error bars.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: grouping column; its unique values, cast to int, label the ticks.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: seaborn palette for the bars.
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    # Sorted unique group values as int tick labels (assumes seaborn's default sorted
    # order for numeric categories). Builtin int: np.int was removed in NumPy 1.24.
    data_xtick_arr = np.unique(dataframe[f"{x_axis}"].values).astype(int)

    # seaborn's default error bar is a 95% CI; request the standard deviation so the
    # bars match the "(mean+std)" titles (`errorbar` since seaborn 0.12, `ci` before).
    sns_version = tuple(int(part) for part in re.findall(r"\d+", sns.__version__)[:2])
    sd_kwargs = {"errorbar": "sd"} if sns_version >= (0, 12) else {"ci": "sd"}

    for ax, y_axis in zip(np.ravel(axes), y_axes):
        sns.barplot(x=f"{x_axis}", y=f"{y_axis}",
            data=dataframe,
            palette=palette,
            capsize=.0, ax = ax, **sd_kwargs)
        ax.set_title(f'{y_axis.upper()} (mean+std)', fontsize=10)
        ax.set_xticklabels(data_xtick_arr, rotation=45)

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_pointplot_mean_std(dataframe, y_axes, x_axis, grid_shape = None, palette=None, axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot'):
    """Point plot of the per-group mean of each ``y_axes`` column with ±1 std error bars.

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: grouping column; its unique values, cast to int, label the ticks.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: seaborn palette passed to ``sns.pointplot``.
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    # Tick labels come from the frame being plotted. Builtin int: np.int was removed
    # in NumPy 1.24.
    data_xtick_arr = np.unique(dataframe[f"{x_axis}"].values).astype(int)

    # seaborn's default error bar is a 95% CI; request the standard deviation so the
    # points match the "(mean+std)" titles (`errorbar` since seaborn 0.12, `ci` before).
    sns_version = tuple(int(part) for part in re.findall(r"\d+", sns.__version__)[:2])
    sd_kwargs = {"errorbar": "sd"} if sns_version >= (0, 12) else {"ci": "sd"}

    for ax, y_axis in zip(np.ravel(axes), y_axes):
        sns.pointplot(x=f"{x_axis}", y=f"{y_axis}",
            data=dataframe,
            palette=palette,
            capsize=.0, ax = ax, **sd_kwargs)
        ax.set_title(f'{y_axis.upper()} (mean+std)', fontsize=10)
        # pointplot draws groups at categorical positions 0..n-1, so a log x scale
        # would hide the first group (x=0); label the category ticks instead.
        ax.set_xticklabels(data_xtick_arr, rotation=45)

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

In [None]:
def graphics_regplot_mean_std(dataframe, y_axes, x_axis, grid_shape = None, palette=None, axes = None, figsize = (15, 5), show_fig = False, title = 'Complex Plot', order = 4):
    """Fit and draw a polynomial regression of each ``y_axes`` column on ``x_axis``.

    Each distinct ``x_axis`` value is drawn as the mean of its group
    (``x_estimator=np.mean``).

    Args:
        dataframe: pandas DataFrame with the ``x_axis`` and ``y_axes`` columns.
        y_axes: column names, one subplot each.
        x_axis: column used as the regressor.
        grid_shape: ``(rows, cols)`` for ``plt.subplots``; required when ``axes`` is None.
        palette: unused (kept for a signature shared with the other helpers).
        axes: existing Axes (single, sequence or ndarray) to draw into.
        figsize: figure size, used only when a figure is created.
        show_fig: call ``plt.show()`` when a figure is created.
        title: suptitle, used only when a figure is created.
        order: polynomial order of the fit (default 4).

    Returns:
        ``axes`` when supplied by the caller, otherwise ``(fig, axes)``.

    Raises:
        TypeError: if both ``axes`` and ``grid_shape`` are None.
    """
    created_figure = axes is None
    if created_figure:
        if grid_shape is None:
            raise TypeError("grid_shape is required when axes is not provided")
        fig, axes = plt.subplots(*grid_shape, figsize=figsize)
        fig.suptitle(f'{title}', fontsize=15)

    for ax, y_axis in zip(np.ravel(axes), y_axes):
        sns.regplot(x=f"{x_axis}", y=f"{y_axis}", data=dataframe,
                label = f'{y_axis.upper()}',
                x_estimator=np.mean,
                ax = ax,
                order=order, ci=68)  # 68% confidence band / error bars
        ax.set_title(f'{y_axis.upper()} | poly-regression order {order}°', fontsize=10)

    if not created_figure:
        return axes
    if show_fig:
        plt.show()
    return fig, axes

## Code

### Fetch Data

In [None]:
def adjust_date_format(date_input):
    """Reverse the order of the '-'-separated fields, e.g. '2021-03-15' -> '15-03-2021'."""
    fields = date_input.split('-')
    fields.reverse()
    return '-'.join(fields)
# Per-train sub-folder "<reversed date>/<timestamp with '.' -> '-'>".
date_inputs_tmp = [adjust_date_format(date_input) for date_input in dates_input]
print(date_inputs_tmp)

train_timestamps_tmp = list(map(lambda stamp: stamp.replace('.', '-'), train_timestamps))
print(train_timestamps_tmp)

trains_datetime = [
    os.path.join(day_part, time_part)
    for day_part, time_part in zip(date_inputs_tmp, train_timestamps_tmp)
]
print(trains_datetime)
print(trains_no)
print(trains_no)

In [None]:
def adjust_trains_path(root_path, trains_no, trains_datetime = None):
    """Build the path of each train's ``result_comb_train_<no>.txt`` results file.

    Args:
        root_path: "." or "/content" (a trailing slash is ignored) when the files sit
            directly in that folder. For any other root, the files are expected under
            ``<root_path>/<train_datetime>/train/``.
        trains_no: train numbers, one file each.
        trains_datetime: per-train sub-folders, paired positionally with
            ``trains_no``; required for roots other than "." and "/content".

    Returns:
        list of paths, in the order of ``trains_no``.

    Raises:
        ValueError: if ``trains_datetime`` is needed but missing or of a different
            length than ``trains_no``.
    """
    # Normalise "/content/" (the Colab form option) and "./" to their canonical form.
    normalized_root = root_path.rstrip("/") or "/"
    if normalized_root in ("/content", "."):
        return [os.path.join(normalized_root, f'result_comb_train_{train_no}.txt')
                for train_no in trains_no]

    if trains_datetime is None or len(trains_datetime) != len(trains_no):
        raise ValueError(
            f"trains_datetime must list one sub-folder per train for root {root_path!r}")
    return [os.path.join(root_path, train_datetime, "train", f'result_comb_train_{train_no}.txt')
            for train_datetime, train_no in zip(trains_datetime, trains_no)]

# Pass the per-train sub-folders too: they are required whenever root_path is neither
# "." nor "/content" (e.g. the Drive results tree offered in the form).
path_history_trains = adjust_trains_path(root_path, trains_no, trains_datetime)
print("Path location:")
print(path_history_trains)

In [None]:
def fetch_data_from_gdrive(path_history_train, file_id, fetch_data_from_gdrive_checkbox):
    """Download a Google Drive file to ``path_history_train`` unless it already exists.

    Args:
        path_history_train: local destination path.
        file_id: Drive file id (contained in the file's shareable link). An empty id
            is skipped with a message.
        fetch_data_from_gdrive_checkbox: when falsy, nothing is done.

    Uses the module-level OAuth ``creds``. Missing parent folders are created. On any
    failure the partially written file is removed, so a re-run retries the download
    instead of reporting it as already present; the error is re-raised.
    """
    if not fetch_data_from_gdrive_checkbox:
        return
    if os.path.exists(f'{path_history_train}'):
        print(f"Already exists: {path_history_train}")
        return
    if not file_id:
        print(f"No Drive file id given, skipping: {path_history_train}")
        return

    DRIVE = discovery.build('drive', 'v3', http=creds.authorize(Http()))
    request = DRIVE.files().get_media(fileId=file_id)

    parent_dir = os.path.dirname(path_history_train)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        # `with` guarantees the file handle is closed even if a chunk fails.
        with io.FileIO(f'{path_history_train}', mode='w') as fh:
            downloader = MediaIoBaseDownload(fh, request)
            done = False
            while not done:
                status, done = downloader.next_chunk()
                print("Download %d%%." % int(status.progress() * 100))
    except BaseException:
        if os.path.exists(path_history_train):
            os.remove(path_history_train)
        raise

# Fetch each configured results file (or report that it is already present).
for path_history_train, file_id in zip(path_history_trains, file_ids):
    print(path_history_train)
    fetch_data_from_gdrive(
        path_history_train,
        file_id,
        fetch_data_from_gdrive_checkbox,
    )

In [None]:
columns_df = ['#params', 'seed', 'hl', 'hf', 'mse', 'psnr', 'ssim', 'train_eta']

# Load each results file as a 2-D array. ndmin=2 keeps single-row files 2-D, so
# every file (including the first) concatenates along the row axis.
loaded_arrays = []
for path_history_train in path_history_trains:
    print(path_history_train)
    loaded_arrays.append(np.loadtxt(path_history_train, ndmin=2))
# Files without rows contribute nothing and would break the column alignment.
loaded_arrays = [arr for arr in loaded_arrays if arr.size]

results_history_arr = np.concatenate(loaded_arrays, axis = 0) if loaded_arrays else None

results_history_df = pd.DataFrame(
    data = results_history_arr,
    columns = columns_df)

### Dataframe: brief description

In [None]:
# First five rows of the combined results table.
results_history_df.head(5)

In [None]:
# Column dtypes and non-null counts of the results table.
results_history_df.info()

In [None]:
# Summary statistics of every numeric column.
results_history_df.describe()

In [None]:
# Number of rows recorded for each distinct value of the 'hf' column.
collections.Counter(results_history_df["hf"].values)

### Dataframe: in depth description

#### Scatter - Plot

In [None]:
# Pairwise relationships between the remaining columns ('hf', 'hl' and 'seed' are
# excluded): scatter above the diagonal, KDE below it and on the diagonal.
metrics_df = results_history_df.drop(columns=['hf', 'hl', 'seed'])
g = sns.PairGrid(metrics_df, diag_sharey=False)
g.map_upper(sns.scatterplot, s=15)
g.map_lower(sns.kdeplot)
g.map_diag(sns.kdeplot, lw=2)
plt.savefig(f'{images_conf.df_scatterplot}')

#### Plots

In [None]:
# Mean and standard deviation of MSE, PSNR, SSIM and training time, grouped by the
# '#params' column of `results_history_df`.
summary_metrics = ['mse', 'psnr', 'ssim', 'train_eta']
summary_columns = [(metric, stat) for metric in summary_metrics for stat in ('mean', 'std')]
(
    results_history_df
    .groupby(by = ['#params'])[summary_metrics]
    .describe()[summary_columns]
)

In [None]:
grid_shape = "(2, 2)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

fig, axes = graphics_scatterplot(
    dataframe = results_history_df,
    y_axes = ("mse", "psnr", "ssim", "train_eta"),
    x_axis = "#params",
    grid_shape = grid_shape,
    figsize = (20, 10))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
plt.savefig(f"{images_conf.scatterplot}")

plt.show()

In [None]:
results_history_sorted_df = results_history_df.sort_values(by=['#params', 'hf', 'hl'])

grid_shape = "(1, 4)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

fig, axes = graphics_bars_mean_std(
    dataframe = results_history_sorted_df,
    y_axes = ("mse", "psnr", "ssim", "train_eta"),
    x_axis = "#params",
    grid_shape = grid_shape,
    figsize = (20, 5))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
plt.savefig(f"{images_conf.barplot}")
plt.show()

In [None]:
results_history_sorted_df = results_history_df.sort_values(by=['#params', 'hf', 'hl'])

grid_shape = "(1, 4)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

fig, axes = graphics_pointplot_mean_std(
    dataframe = results_history_sorted_df,
    y_axes = ("mse", "psnr", "ssim", "train_eta"),
    x_axis = "#params",
    grid_shape = grid_shape,
    figsize = (20, 5))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
# Adjust the layout before saving so the PNG on disk matches the displayed figure.
plt.subplots_adjust(bottom = 0.2)
plt.savefig(f"{images_conf.pointplot}")
plt.show()

In [None]:
results_history_sorted_df = results_history_df.sort_values(by=['#params', 'hf', 'hl'])

grid_shape = "(1, 4)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

fig, axes = graphics_boxplot(
    dataframe = results_history_sorted_df,
    y_axes = ("mse", "psnr", "ssim", "train_eta"),
    x_axis = "#params",
    grid_shape = grid_shape,
    figsize = (20, 5))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
# Adjust the layout before saving so the PNG on disk matches the displayed figure.
plt.subplots_adjust(bottom = 0.2)
plt.savefig(f"{images_conf.boxplot}")
plt.show()

In [None]:
results_history_sorted_df = results_history_df.sort_values(by=['#params', 'hf', 'hl'])

grid_shape = "(1, 4)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    fig, axes = graphics_violinplot(
        dataframe = results_history_sorted_df,
        y_axes = ("mse", "psnr", "ssim", "train_eta"),
        x_axis = "#params",
        grid_shape = grid_shape,
        figsize = (20, 5))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
# Adjust the layout before saving so the PNG on disk matches the displayed figure.
plt.subplots_adjust(bottom = 0.2)
plt.savefig(f"{images_conf.violinplot}")
plt.show()

In [None]:
results_history_sorted_df = results_history_df.sort_values(by=['#params', 'hf', 'hl'])

grid_shape = "(1, 4)" #@param ["(1, 4)", "(4, 1)", "(2, 2)"]
# literal_eval only accepts Python literals, so the form value cannot execute code.
grid_shape = ast.literal_eval(grid_shape)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    fig, axes = graphics_regplot_mean_std(
        dataframe = results_history_sorted_df,
        y_axes = ("mse", "psnr", "ssim", "train_eta"),
        x_axis = "#params",
        grid_shape = grid_shape,
        figsize = (20, 10))
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
# Adjust the layout before saving so the PNG on disk matches the displayed figure.
plt.subplots_adjust(bottom = 0.2)
plt.savefig(f"{images_conf.regplot}")
plt.show()

### Summary Graph

In [None]:
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    fig, axes = graphics_bars_pointplot(
        dataframe = results_history_sorted_df,
        y_axes = ("mse", "psnr", "ssim", "train_eta"),
        x_axis = "#params",
        grid_shape = (6, 4),
        figsize = (20, 20), palette="Blues_d",
        show_fig = False,
        title = 'Complex Plot')
fig.suptitle('Trend MSE and PSNR et al. across archs (grouped by #params).', fontsize=15)
# Adjust the layout before saving so the PNG on disk matches the displayed figure.
plt.subplots_adjust(bottom = 0.2)
plt.savefig(f"{images_conf.complexplot}")
plt.show()

### Download files

In [None]:
# List the PNG figures written to the output folder.
saved_pictures = list(Path(basedir_path_out_images).glob('*.png'))
print(f"Pictures ({len(saved_pictures)}):")
for path in saved_pictures:
    target_file = os.path.join(basedir_path_out_images, path.name)
    print(f"{target_file}")
    # Colab only: if download_pictures_checkbox: files.download(target_file)

### Compare obtained results between JPEG and Siren

In [None]:
# Load target image.
image_file_path = 'test068.png'
im = Image.open(f'{image_file_path}')
print('Image size:', im.size)
im

In [None]:
# Test how to manually crop image from its center.
width, height = im.size 

left = width - height
top = 0
right = width
bottom = height

im.crop((left, top, right, bottom))

In [None]:
# JPEG quality values 1..99 (inclusive) to be tested in compression.
# Builtin int replaces np.int, which was removed in NumPy 1.24.
qualities_arr = np.arange(1, 99+1, dtype = int)

In [None]:
# Run several trials for JPEG compression.

# Record of one trial of compressing the target image.
# NOTE(review): the field name 'heigth' is misspelled; kept so code reading it by name
# keeps working.
name_obj = 'WeightsPsnr'
attributes = ['psnr', 'quality', 'file_size_bits', 'bpp', 'width', 'heigth', 'CR']
WeightsPsnr = collections.namedtuple(name_obj, attributes)

# Results of successful trials, and the quality values whose trial raised an error.
result_tuples = []
failure_qualities = []

# Size in bits of the source image file: the numerator of the compression ratio.
# NOTE(review): this is the whole (uncropped) source file, while each JPEG encodes only
# the crop, so CR is not like-for-like whenever the crop is smaller — confirm intent.
original_file_size_bits = Path(image_file_path).stat().st_size * 8

# Crop the image to the desired shape, then try every quality value on that crop.
for edges in [(left, top, right, bottom)]:
    left, top, right, bottom = list(map(int, edges))
    im_tmp = im.crop((left, top, right, bottom))
    reference_arr = np.asarray(im_tmp)

    for quality in qualities_arr:
        try:
            # Encode in memory: the buffer holds exactly the bytes a .jpg file would.
            jpeg_buffer = io.BytesIO()
            im_tmp.save(jpeg_buffer, format='JPEG', quality = int(quality))
            file_size_bits = len(jpeg_buffer.getvalue()) * 8
            jpeg_buffer.seek(0)
            with Image.open(jpeg_buffer) as im_jpeg:
                assert im_jpeg.size == im_tmp.size, "im_jpeg.size != im_tmp.size"
                width, height = im_jpeg.size
                decoded_arr = np.asarray(im_jpeg)

            pixels = width * height
            bpp = file_size_bits / pixels
            psnr_score = psnr(reference_arr, decoded_arr, data_range=255)
            CR = original_file_size_bits / file_size_bits

            values = [psnr_score, quality, file_size_bits, bpp, width, height, CR]
            result_tuples.append(WeightsPsnr._make(values))
        except Exception as err:
            # Keep track of quality values the image could not be compressed with.
            print(err)
            failure_qualities.append(quality)

In [None]:
# JPEG trials: PSNR against compressed size in bits.
x = np.array([trial.file_size_bits for trial in result_tuples])
y = np.array([trial.psnr for trial in result_tuples])

fig = plt.figure()
plt.scatter(x, y, marker = 'x', color = sns.color_palette()[1], label = 'jpeg')
plt.ylabel('PSNR')
plt.xlabel('# Bits')
plt.legend()
plt.title('PSNR vs. # Bits')
plt.show()

In [None]:
# JPEG trials: compression ratio against bits per pixel.
x = np.array([trial.bpp for trial in result_tuples])
y = np.array([trial.CR for trial in result_tuples])

fig = plt.figure()
plt.scatter(x, y, marker = 'x', color = sns.color_palette()[1], label = 'jpeg')
plt.ylabel('CR')
plt.xlabel('BPP')
plt.legend()
plt.title('CR vs. BPP')
plt.show()

In [None]:
# JPEG trials: compression ratio against compressed size in bits.
x = np.array([trial.file_size_bits for trial in result_tuples])
y = np.array([trial.CR for trial in result_tuples])

fig = plt.figure()
plt.scatter(x, y, marker = 'x', color = sns.color_palette()[1], label = 'jpeg')
plt.ylabel('CR')
plt.xlabel('# Bits')
plt.legend()
plt.title('CR vs. # Bits')
plt.show()

In [None]:
# JPEG trials: PSNR against bits per pixel.
x = np.array([trial.bpp for trial in result_tuples])
y = np.array([trial.psnr for trial in result_tuples])

fig = plt.figure()
plt.scatter(x, y, marker = 'x', color = sns.color_palette()[1], label = 'jpeg')
plt.ylabel('PSNR')
plt.xlabel('BPP')
plt.legend()
plt.title('PSNR vs. BPP')
plt.show()

In [None]:
# JPEG: scatter every (x, y) pair of compression metrics on one grid and save it.
x_axes = "bpp;file_size_bits".split(";")
y_axes = "psnr;CR".split(";")

# Reuse the shared helper instead of duplicating its plotting loop here.
fig, axes = compute_graph_image_psnr_CR(
    data_tuples = result_tuples,
    x_axes = x_axes,
    y_axes = y_axes,
    subject = 'jpeg')
fig.suptitle('JPEG', fontsize=15)
plt.savefig('complex_plot_jpge_res.png')
plt.show()

In [None]:
# Same JPEG metric grid through the shared helper; displayed only, not saved.
x_axes = ["bpp", "file_size_bits"]
y_axes = ["psnr", "CR"]

fig, axes = compute_graph_image_psnr_CR(
    data_tuples = result_tuples,
    x_axes = x_axes,
    y_axes = y_axes,
    subject = 'jpeg',
    colors = sns.color_palette())
fig.suptitle('JPEG', fontsize=15)
plt.show()

In [None]:
# Compare PSNR of Siren and JPEG against storage cost in bits (log x axis).
# Siren cost is #params * 32, presumably one float32 per weight.
BITS_PER_PARAM = 32
fig = plt.figure()

siren_bits = results_history_sorted_df['#params'].values * BITS_PER_PARAM
siren_psnr = results_history_sorted_df['psnr'].values
plt.scatter(x = siren_bits,
    y = siren_psnr,
    marker = 'x',
    color = sns.color_palette()[0],
    label = 'siren')

x = np.array([trial.file_size_bits for trial in result_tuples])
y = np.array([trial.psnr for trial in result_tuples])
plt.scatter(x, y, marker = 'x', color = sns.color_palette()[1], label = 'jpeg')

plt.xscale('log')
plt.ylabel('PSNR')
plt.xlabel('# Bits')
plt.legend()
plt.title('PSNR vs. # Bits')
plt.show()

## References

- Wiki references:
 - [Image Compression](https://en.m.wikipedia.org/wiki/Image_compression)
 - [Bit Rate](https://en.wikipedia.org/wiki/Bit_rate)

- Generic references:
 - [FORM](https://colab.research.google.com/notebooks/forms.ipynb#scrollTo=3jKM6GfzlgpS)
 - [Jupyter - themes](https://stackoverflow.com/questions/46510192/change-the-theme-in-jupyter-notebook)