In [None]:
from pathlib import Path
from collections import namedtuple, defaultdict
from statistics import mean

import pandas as pd
from bokeh.plotting import output_notebook, figure, ColumnDataSource, show
from bokeh.models import NumeralTickFormatter
from bokeh.palettes import Category10_3 as PALETTE

# One loaded result file: `attrs` is the dict of settings parsed from the
# file name and `data` is the DataFrame read from the file (see read_data).
Trial = namedtuple('Trial', ['attrs', 'data'])

def read_data(path):
    """Load all trial CSVs in a directory and average their rewards across seeds.

    Every ``*.csv`` file name (minus the extension) must encode its settings as
    comma-separated ``name=value`` pairs, e.g. ``agent_type=kb,random_seed=3``.
    Values made only of decimal digits are converted to ``int``; everything else
    stays a string. Files whose settings differ only in ``random_seed`` are
    treated as repeats of one configuration and averaged together.

    Each file is read as header-less, space-separated rows of
    ``timestamp episode reward``; the timestamp column is dropped.

    Args:
        path: directory containing the CSV files (``pathlib.Path`` or ``str``).

    Returns:
        dict mapping a configuration key -- a tuple of ``(name, value)`` pairs
        sorted by name, with ``random_seed`` left out -- to a DataFrame with
        columns ``episode`` and ``reward``. ``reward`` is the mean over that
        configuration's files; an episode missing from some files is averaged
        over the files that contain it. The dict is empty if no CSV is found.

    Raises:
        ValueError: if a comma-separated part of a file name has no ``=``.
    """
    grouped = defaultdict(list)
    # Sorted so files are always combined in the same order; raw glob order
    # depends on the filesystem.
    for filepath in sorted(Path(path).glob('*.csv')):
        attrs = {}
        for pair in filepath.stem.split(','):
            # maxsplit=1 keeps any '=' that appears inside the value intact.
            name, value = pair.split('=', 1)
            # isdecimal rather than isdigit: characters such as '²' pass
            # isdigit() but make int() raise.
            attrs[name] = int(value) if value.isdecimal() else value
        # Seeds are excluded so that repeated runs share one configuration key.
        config = tuple(sorted((k, v) for k, v in attrs.items() if k != 'random_seed'))
        df = pd.read_csv(
            filepath,
            sep=' ',
            names=['timestamp', 'episode', 'reward'],
            usecols=['episode', 'reward'],
            index_col='episode',
        )
        grouped[config].append(Trial(attrs, df))
    return {
        config: (
            # One 'reward' column per trial, aligned on the episode index.
            pd.concat([trial.data for trial in trials], axis=1)
            .mean(axis=1)
            .reset_index(name='reward')
        )
        for config, trials in grouped.items()
    }

# Display Bokeh figures inline in this notebook (used by the `show` calls in
# the plotting cells below).
output_notebook()

## Experiment 1: album data organization (decade vs. artist vs. country)

In [None]:
def plot_experiment_1(path=Path('results/experiment1'), rolling_size=51):
    """Plot smoothed mean reward per episode for each album data organization.

    One line is drawn per ``data_file`` configuration (``album_date``,
    ``album_artist``, ``album_country``); the ``date`` variant is labelled
    "decade" in the legend.

    Args:
        path: directory of trial CSVs, passed to ``read_data``.
        rolling_size: width of the centered rolling-mean window; it is applied
            twice to each reward curve.
    """
    data = read_data(path)
    fig = figure(
        width=1200, height=800,
        x_axis_label='Episode', x_range=[0, 150_000],
        y_axis_label='Total Reward', y_range=[-80, 5],
    )
    # Dashed grey horizontal reference line at reward -10; x extends beyond the
    # visible range so the line spans the whole plot.
    # NOTE(review): the meaning of the -10 threshold is not stated here — confirm.
    fig.line(
        x=[-500_000, 500_000],
        y=-10,
        line_width=5,
        color='#C0C0C0',
        line_dash=[10, 30],
    )
    for index, organization in enumerate(['date', 'artist', 'country']):
        df = data[(('data_file', f'album_{organization}'),)]
        # Smooth only the reward. The episode column is the x-axis; averaging it
        # over a window would shift x values wherever episodes are unevenly spaced.
        reward = (
            df['reward']
            .rolling(window=rolling_size, center=True).mean()
            .rolling(window=rolling_size, center=True).mean()
        )
        fig.line(
            x='episode',
            y='reward',
            source=ColumnDataSource(df.assign(reward=reward)),
            color=PALETTE[index],
            line_width=5,
            legend_label=('decade' if organization == 'date' else organization),
        )
    fig.xaxis[0].formatter = NumeralTickFormatter(format="0a")
    fig.outline_line_color = None
    fig.axis.axis_line_width = 5
    fig.axis.axis_label_text_font_size = '40pt'
    fig.axis.major_label_text_font_size = '40pt'
    fig.axis.major_label_text_font = 'times'
    fig.grid.grid_line_color = None
    fig.axis.minor_tick_in = 0
    fig.axis.minor_tick_out = 0
    fig.legend.location = 'bottom_right'
    fig.legend.label_text_font_size = '30pt'
    fig.legend.glyph_width = 100
    show(fig)
        
# Build and show the Experiment 1 figure.
plot_experiment_1()

## Experiment 2: naive agent vs. kb agent (max 1 and 4 internal actions)

In [None]:
def plot_experiment_2(path=Path('results/experiment2'), rolling_size=51):
    """Plot smoothed mean reward per episode for each agent configuration.

    Configurations are ``(agent_type, max_internal_actions)`` pairs: the naive
    agent with 0 internal actions and the kb agent with 1 and 4. Lines are told
    apart by dash pattern and labelled ``naive``, ``kb-1`` and ``kb-4``.

    Args:
        path: directory of trial CSVs, passed to ``read_data``.
        rolling_size: width of the centered rolling-mean window applied once to
            each reward curve.
    """
    data = read_data(path)
    fig = figure(
        width=1200, height=800,
        x_axis_label='Episode', x_range=[0, 310_000],
        y_axis_label='Total Reward', y_range=[-80, 5],
    )
    # Dashed grey horizontal reference line at reward -10; x extends beyond the
    # visible range so the line spans the whole plot.
    # NOTE(review): the meaning of the -10 threshold is not stated here — confirm.
    fig.line(
        x=[-500_000, 500_000],
        y=-10,
        line_width=5,
        color='#C0C0C0',
        line_dash=[10, 30],
    )
    # Dash pattern per (agent_type, max_internal_actions); dict order is also
    # the plotting and legend order.
    dashes = {
        ('naive', 0): [1, 0],
        ('kb', 1): [10, 10],
        ('kb', 4): [30, 20],
    }
    for (agent, internals), dash in dashes.items():
        # Key layout matches read_data: (name, value) pairs sorted by name.
        df = data[(('agent_type', agent), ('max_internal_actions', internals))]
        # Smooth only the reward; averaging the episode (x-axis) column would
        # shift x values wherever episodes are unevenly spaced.
        reward = df['reward'].rolling(window=rolling_size, center=True).mean()
        fig.line(
            x='episode',
            y='reward',
            source=ColumnDataSource(df.assign(reward=reward)),
            line_width=5,
            line_dash=dash,
            legend_label=('naive' if agent == 'naive' else f'kb-{internals}'),
        )
    fig.xaxis[0].formatter = NumeralTickFormatter(format="0a")
    fig.outline_line_color = None
    fig.axis.axis_line_width = 5
    fig.axis.axis_label_text_font_size = '40pt'
    fig.axis.major_label_text_font_size = '40pt'
    fig.axis.major_label_text_font = 'times'
    fig.grid.grid_line_color = None
    fig.axis.minor_tick_in = 0
    fig.axis.minor_tick_out = 0
    fig.legend.location = 'bottom_right'
    fig.legend.label_text_font_size = '30pt'
    fig.legend.glyph_width = 100
    show(fig)
        
# Build and show the Experiment 2 figure.
plot_experiment_2()

## Experiment 3: single configuration, averaged over random seeds

In [None]:
def plot_experiment_3(path=Path('results/experiment3'), rolling_size=51):
    """Plot the smoothed mean reward per episode for Experiment 3.

    The CSVs in ``path`` are expected to differ only in ``random_seed``, so
    ``read_data`` groups them all under the empty configuration key ``()``.
    A single line is drawn, with no legend.

    Args:
        path: directory of trial CSVs, passed to ``read_data``.
        rolling_size: width of the centered rolling-mean window applied once to
            the reward curve.
    """
    data = read_data(path)
    fig = figure(
        width=1200, height=800,
        x_axis_label='Episode', x_range=[0, 310_000],
        y_axis_label='Total Reward', y_range=[-80, 5],
    )
    # Dashed grey horizontal reference line at reward -10; x extends beyond the
    # visible range so the line spans the whole plot.
    # NOTE(review): the meaning of the -10 threshold is not stated here — confirm.
    fig.line(
        x=[-500_000, 500_000],
        y=-10,
        line_width=5,
        color='#C0C0C0',
        line_dash=[10, 30],
    )
    df = data[()]
    # Smooth only the reward; averaging the episode (x-axis) column would
    # shift x values wherever episodes are unevenly spaced.
    reward = df['reward'].rolling(window=rolling_size, center=True).mean()
    fig.line(
        x='episode',
        y='reward',
        source=ColumnDataSource(df.assign(reward=reward)),
        line_width=5,
    )
    fig.xaxis[0].formatter = NumeralTickFormatter(format="0a")
    fig.outline_line_color = None
    fig.axis.axis_line_width = 5
    fig.axis.axis_label_text_font_size = '40pt'
    fig.axis.major_label_text_font_size = '40pt'
    fig.axis.major_label_text_font = 'times'
    fig.grid.grid_line_color = None
    fig.axis.minor_tick_in = 0
    fig.axis.minor_tick_out = 0
    show(fig)
        
# Build and show the Experiment 3 figure.
plot_experiment_3()