In [51]:
%%capture stored_output
# Reload edited Python modules automatically before each cell runs.
# %%capture keeps anything these magics print out of the notebook, e.g. the
# "already loaded" notice on a re-run; it is stored in `stored_output` instead.
%load_ext autoreload
%autoreload 2


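# Visualize results

Click **Refresh** to load every `results/*.parquet` file (one table per environment). Then pick an environment, the algorithms and the scorers to compare, and press **Show results**. Tick *Aggregate trials?* to merge the trials of each algorithm into one line per scorer. Leave it unticked to get one row of panels per trial (`date-time`). Tick *Save graph?* to also write the figure to a PNG file.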

In [52]:
import os
import pandas as pd
import seaborn as sns
import datetime
from ipywidgets import interact, widgets
from IPython.display import display, clear_output
from typing import List

# Shared state of the visualizer: filled by `load_data`, read by
# `run_visualizer` and `plot_df`. No `global` statement is needed at module
# level (it is a no-op here); the widget handles (env_w, algo_w, ...) are
# created as globals by `run_visualizer`.
tables = {}                 # environment name -> results DataFrame
algos: List[str] = []       # sorted, de-duplicated algorithm names over all tables
metrics: List[str] = []     # sorted, de-duplicated scorer names over all tables
palette_dict = {}           # algorithm name -> colour, so each algo keeps its colour across plots


In [53]:
def plot_df(b):
    """Draw the currently selected results into ``plot_output``.

    Callback of the "Show results" button. Reads the selections from the
    module-level widgets (``env_w``, ``algo_w``, ``scorers_w``,
    ``aggregate_w``, ``save_w``), filters ``tables[env]`` down to the chosen
    algorithms and metrics, and draws one column of panels per metric:

    * aggregate: ``sns.relplot`` line plot with the ``date-time`` column
      dropped, so the trials of each algorithm are combined per epoch;
    * otherwise: a ``sns.FacetGrid`` with one row per ``date-time`` value.

    When "Save graph?" is ticked the figure is also written to
    ``<env>_<ddmmYYYY_HHMM>.png`` in the current working directory.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    with plot_output:
        env_name = env_w.value
        algo_list = algo_w.value
        scorers_list = scorers_w.value
        aggregate = aggregate_w.value
        save = save_w.value

        # Guard clauses: before a refresh env_w.value is None, and
        # `tables[None]` would raise a KeyError that the handler below does
        # not catch; empty selections would try to plot an empty frame.
        if env_name not in tables:
            print("No results loaded for environment {!r}: click 'Refresh' first".format(env_name))
            return
        if not algo_list or not scorers_list:
            print("Select at least one algorithm and one scorer to plot")
            return

        df = tables[env_name].drop(columns=['env'], errors='ignore')
        df = df.loc[df['algo'].isin(algo_list) & df['metric'].isin(scorers_list)]
        # drop=True: the pre-filter index is meaningless here and would
        # otherwise be added as a spurious "index" column.
        df = df.reset_index(drop=True)

        print("Plotting {} with {} and {}, aggregate: {}".format(env_name, algo_list, scorers_list, aggregate))

        try:
            if aggregate:
                # Without the trial identifier, rows sharing (epoch, algo,
                # metric) are aggregated by seaborn's line plot.
                df = df.drop(columns=['date-time'], errors='ignore')
                plot = sns.relplot(
                    data=df,
                    x="epoch",
                    y="value",
                    col="metric",
                    hue="algo",
                    style="algo",
                    kind="line",
                    height=5,
                    aspect=1.7,
                    dashes=False,
                    palette=palette_dict,
                )
            else:
                plot = sns.FacetGrid(
                    data=df,
                    col="metric",
                    hue="algo",
                    row="date-time",
                    height=5,
                    aspect=1.7,
                    legend_out=True,
                    sharey=False,
                    sharex=False,
                    palette=palette_dict,
                )
                plot.map_dataframe(sns.lineplot, "epoch", "value")
                plot.add_legend()

            if save:
                plot.savefig("{}_{}.png".format(env_name, datetime.datetime.now().strftime("%d%m%Y_%H%M")))
        except ValueError as err:
            # Report why seaborn refused instead of hiding the cause.
            print("It was not possible to plot the requested data: {}".format(err))
            return
        print("Done!")

In [54]:
def load_data(b):
    """Reload every ``results/*.parquet`` file and redraw the visualizer.

    Callback of the "Refresh" button. Rebuilds the module-level state:

    * ``tables``: one DataFrame per parquet file, keyed by the file name
      without its ``.parquet`` extension (the environment name);
    * ``algos`` / ``metrics``: sorted unique values of the ``algo`` /
      ``metric`` columns over all tables;
    * ``palette_dict``: a fixed colour per algorithm.

    It then calls ``run_visualizer`` so the widgets show the new options.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).

    Raises:
        FileNotFoundError: if the ``results`` directory does not exist.
    """
    global tables, algos, palette_dict, metrics
    clear_output(wait=False)
    with load_output:
        # Start from an empty mapping so environments whose file was removed
        # disappear on refresh instead of lingering from a previous load.
        tables = {}
        # Context manager closes the directory iterator deterministically.
        with os.scandir("results") as entries:
            for entry in entries:
                if entry.name.endswith('.parquet'):
                    # Strip only the extension: split('.')[0] would truncate
                    # (and make collide) names such as "env.v2.parquet".
                    env = entry.name[:-len('.parquet')]
                    tables[env] = pd.read_parquet(entry.path)

        algos = sorted({algo for table in tables.values() for algo in table["algo"].unique()})
        metrics = sorted({metric for table in tables.values() for metric in table["metric"].unique()})
        print("Loaded {} environment(s); scorers: {}".format(len(tables), metrics))

        # "tab10" has only 10 colours: asking for n_colors makes seaborn cycle
        # it, so every algorithm gets an entry (a missing key makes seaborn
        # raise "palette dictionary is missing keys" when plotting).
        colors = sns.color_palette("tab10", n_colors=max(len(algos), 1))
        palette_dict = dict(zip(algos, colors))
    run_visualizer()

In [None]:
def run_visualizer():
    """Create and display the control panel of the results visualizer.

    Builds fresh widgets from the current ``tables``, ``algos`` and
    ``metrics`` and stores them in module-level globals, where the button
    callbacks ``load_data`` (Refresh) and ``plot_df`` (Show results) read
    them. Widgets are displayed top to bottom: Refresh button and its output
    area, environment, algorithms, scorers, the aggregate/save checkboxes,
    then the "Show results" button and the plot output area.
    """
    global plot_output, load_output, env_w, algo_w, scorers_w, aggregate_w, save_w

    # `icon` takes a Font Awesome name WITHOUT the "fa-" prefix; ipywidgets
    # adds the prefix itself, so "fa-refresh" would not render.
    load_butt = widgets.Button(
        description="Refresh",
        button_style='warning',
        icon='refresh',
    )
    load_output = widgets.Output()
    load_butt.on_click(load_data)

    plot_butt = widgets.Button(
        description="Show results",
        button_style='success',
        icon='line-chart',
    )
    plot_output = widgets.Output()
    plot_butt.on_click(plot_df)

    env_w = widgets.ToggleButtons(
        options=list(tables.keys()),
        description='Environment: ',
        disabled=False,
        button_style='info',  # 'success', 'info', 'warning', 'danger' or ''
    )

    algo_w = widgets.SelectMultiple(
        options=algos,
        description='Algorithms: ',
        layout={'width': 'max-content'},
        disabled=False,
        rows=10,
    )

    scorers_w = widgets.SelectMultiple(
        options=metrics,
        description='Scorers: ',
        layout={'width': 'max-content'},
        disabled=False,
        rows=6,
    )

    aggregate_w = widgets.Checkbox(
        value=False,
        description='Aggregate trials? ',
        disabled=False,
        indent=False,
    )

    save_w = widgets.Checkbox(
        value=False,
        description='Save graph? ',
        disabled=False,
        indent=False,
    )

    display(
        load_butt, load_output,
        env_w,
        algo_w,
        scorers_w,
        aggregate_w,
        save_w,
        plot_butt, plot_output,
    )
    

In [None]:
# Draw the control panel; the option lists start empty until "Refresh"
# loads the results/*.parquet files.
run_visualizer()