In [None]:
import fnmatch
import os
import re
import time

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

%matplotlib notebook

In [None]:
# Chart layout: maps each chart title to where its lines come from, in one of two forms:
# - chart_name : list of channel names -> one line per name, read from <name>.csv
# - chart_name : 'per_worker'          -> one line per chart_name_<number>.csv file
# Dict order is the order in which charts are loaded and drawn.
charts_description = {
    'score': ['score_mean', 'score_max', 'online_score'],
    'loss': ['cost', 'policy_loss', 'xentropy_loss', 'value_loss',
             'advantage', 'pred_reward', 'max_logit'],
    'gradients': ['grad_norm_before_clip', 'grad_norm_after_clip'],
    'active relus': ['active_relus'],
    'delay': ['max_delay', 'mean_delay', 'min_delay'],
    'other': ['active_workers', 'dp_per_s'],
}

# Charts drawn with one line per worker file; appended after the charts above.
charts_description.update(dict.fromkeys(
    [
        'cost', 'policy_loss', 'xentropy_loss', 'value_loss',
        'mean_value', 'mean_action', 'mean_state', 'mean_futurereward', 'mean_init_R',
        'fc_value', 'fc_fc0',
    ],
    'per_worker',
))

In [None]:
# Path to the experiment directory that holds the channel CSV files.
# Edit the default below, or set the EXP_DIR environment variable to override it
# without touching the notebook.
EXP_DIR = os.environ.get('EXP_DIR') or '/home/ajedrych/Documents/experiments/exp16_4/'

# Ensure a trailing separator so EXP_DIR + filename forms a valid path.
# os.path.join(..., '') appends one only when missing, and an empty value falls back
# to the current directory (the old EXP_DIR[-1] check raised IndexError on '').
EXP_DIR = os.path.join(EXP_DIR or '.', '')

In [None]:
class Channel:
    """One plotted series: x/y values plus a legend label.

    The values are either read from a comma-separated file with ``x`` and ``y``
    columns (``filepath``) or passed in directly (``x``/``y``).
    """

    def __init__(self, x=None, y=None, filepath=None, title=None):
        """
        Args:
            x, y: sequences of values; used only when ``filepath`` is None.
            filepath: path to a CSV file with ``x`` and ``y`` columns. When it
                contains no data rows yet, both series are empty lists.
            title: label shown in the chart legend.

        Raises:
            FileNotFoundError: if ``filepath`` does not exist (from pandas).
        """
        if filepath is not None:
            self._x, self._y = self._read_xy(filepath)
        else:
            self._x = x
            self._y = y

        self._title = title

    @staticmethod
    def _read_xy(filepath):
        """Return the ``x`` and ``y`` columns of ``filepath``, or two empty lists when it has no data yet."""
        try:
            dt = pd.read_csv(filepath, delimiter=',')
        except pd.errors.EmptyDataError:
            # Zero-byte file (not even a header yet): treat it like "no data yet"
            # instead of crashing the whole notebook.
            return [], []
        if len(dt) == 0:
            # Header only, no rows yet.
            return [], []
        return dt.x, dt.y

    @property
    def x(self):
        """X values of the series."""
        return self._x

    @property
    def y(self):
        """Y values of the series."""
        return self._y

    @property
    def title(self):
        """Legend label of the series."""
        return self._title
    
class Chart:
    """A titled group of channels that are drawn together on one set of axes."""

    def __init__(self, channels, title, lcols=5):
        """
        Args:
            channels: list of Channel objects, in drawing order.
            title: chart title.
            lcols: number of columns in the chart legend.
        """
        self._channels = channels
        self._title = title
        self._lcols = lcols

    def __getitem__(self, index):
        """Return the channel at ``index`` (enables ``chart[i]`` and slicing)."""
        # The original method was named __getindex__ (not a Python dunder) and
        # referenced an undefined global ``channels``, so indexing never worked.
        return self._channels[index]

    # Kept so any explicit callers of the old (misspelled) name still work.
    __getindex__ = __getitem__

    @property
    def title(self):
        """Chart title."""
        return self._title

    @property
    def lcols(self):
        """Number of legend columns."""
        return self._lcols

    def channels(self):
        """Yield the chart's channels in drawing order."""
        yield from self._channels
    

In [None]:
def draw_chart(fig, nrows, ncols, i, chart):
    """Draw ``chart`` into subplot ``i`` of an ``nrows`` x ``ncols`` grid on ``fig``.

    Each channel becomes one line labelled with the channel's title. The legend
    is placed in the upper left and laid out in ``chart.lcols`` columns.

    Returns:
        The Axes that was created, so callers can customise it further.
    """
    ax = fig.add_subplot(nrows, ncols, i)
    ax.title.set_text(chart.title)

    n_lines = 0
    for channel in chart.channels():
        ax.plot(channel.x, channel.y, label=channel.title)
        n_lines += 1

    # A chart with no channels (e.g. no per-worker files yet) has nothing to put
    # in a legend; calling legend() would only emit a matplotlib warning.
    if n_lines:
        ax.legend(loc='upper left', fancybox=True, ncol=chart.lcols)
    return ax

In [None]:
def _per_worker_files(chart_name):
    """Return ``(worker_id, filename)`` pairs for ``<chart_name>_<digits>.csv`` in EXP_DIR, sorted by worker number."""
    # Strict match: only digits between the prefix and ".csv" (the previous fnmatch
    # pattern also accepted names like "cost_1_old.csv").
    pattern = re.compile(re.escape(chart_name) + r'_([0-9]+)\.csv')
    matches = []
    for filename in os.listdir(EXP_DIR):
        match = pattern.fullmatch(filename)
        if match:
            matches.append((match.group(1), filename))
    # os.listdir order is arbitrary; sort numerically so worker 10 follows worker 9
    # and legend order is stable across runs.
    matches.sort(key=lambda pair: int(pair[0]))
    return matches


def to_chart(chart_name, channel_names):
    """Build a Chart titled ``chart_name`` from CSV files in EXP_DIR.

    Args:
        chart_name: chart title; for per-worker charts also the file-name prefix.
        channel_names: either a list of channel names (one ``<name>.csv`` file per
            channel) or the string ``'per_worker'``, meaning one channel per
            ``<chart_name>_<number>.csv`` file, titled with the worker number.

    Returns:
        A Chart with one Channel per file. Per-worker charts use a 9-column legend.

    Raises:
        FileNotFoundError: if a named channel's CSV file does not exist.
    """
    if channel_names == 'per_worker':
        channels = [
            Channel(filepath=os.path.join(EXP_DIR, filename), title=worker_id)
            for worker_id, filename in _per_worker_files(chart_name)
        ]
        return Chart(channels, title=chart_name, lcols=9)

    channels = [
        Channel(filepath=os.path.join(EXP_DIR, name + '.csv'), title=name)
        for name in channel_names
    ]
    return Chart(channels, title=chart_name)

In [None]:
# Load the data for every chart up front, keyed by chart name
# (same order as charts_description).
charts = {
    chart_name: to_chart(chart_name, channel_spec)
    for chart_name, channel_spec in charts_description.items()
}

In [None]:
def draw_single_chart(chart):
    """Create a new figure holding only ``chart`` (a 1x1 grid) and show it."""
    figure = plt.figure()
    draw_chart(figure, nrows=1, ncols=1, i=1, chart=chart)
    figure.show()

In [None]:
# Draw every loaded chart in its own figure, in charts_description order.
for chart in charts.values():
    draw_single_chart(chart)