In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Alternative (smaller) input:
# path = '../output_data/test_output/P3o2fc.h5'
path = '../output_data/P3o2fc.h5'
with pd.HDFStore(path) as store:
    print(store.info())
    data = store.select(key='dat_0')

# Drop the 'test' index level; it is not used by the grouping below.
data.index = data.index.droplevel(['test'])
# Total Y is the sum of its two components.
data['Y'] = data['Y_c'] + data['Y_d']

# Aggregate every row sharing the same (fully_connected, tstep, phi) key.
# Presumably the pooled rows are repeated runs — TODO confirm against the
# HDF5 schema.
group_levels = ['fully_connected', 'tstep', 'phi']
grouped = data.groupby(group_levels)
data_mean = grouped.mean()
data_std = grouped.std()

In [None]:
def make_fig(ncols=2, nrows=2, figsize=(9,9)):
    """Create a figure with an ``nrows`` x ``ncols`` grid of subplots.

    Parameters
    ----------
    ncols, nrows : int
        Grid shape, forwarded to ``plt.subplots``.
    figsize : (float, float)
        Figure (width, height) in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes
        A flat, row-major list of Axes when both ``nrows`` and ``ncols``
        exceed 1. Otherwise it is whatever ``plt.subplots`` returns: a 1-D
        array of Axes for a single row/column, or a single Axes for a 1x1
        grid.
    """
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    # Flatten only true 2-D grids so callers can zip/iterate over panels.
    # An explicit check avoids using a TypeError from iterating an Axes
    # as control flow.
    if np.ndim(axes) == 2:
        axes = list(axes.flat)
    return fig, axes



def plot_observable(ax, observable, 
                    data, 
                    error=None,
                    colors=None,
                    legend=False):
    """Plot one observable against 'tstep', one line per remaining index combination.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    observable : str
        Column of ``data`` (and of ``error``, if given) to plot.
    data : pandas.DataFrame
        MultiIndex frame containing a 'tstep' level. All other levels are
        unstacked, giving one column (= one line) per combination. The loop
        assumes those remaining levels are ordered (fully_connected, phi).
    error : pandas.DataFrame, optional
        Same layout as ``data``. If given, each line is drawn with a
        shaded +/- error band.
    colors : sequence of colors, optional
        Indexed positionally by line. The default gives each distinct phi
        its own Tableau color, shared by both fully_connected values.
    legend : bool
        If True, label the non-fully-connected lines with their phi value
        and draw a legend.
    """
    series = data[observable]
    series_levels = [name for name in series.index.names if name != 'tstep']
    lines = series.unstack(series_levels)
    bands = error[observable].unstack(series_levels) if error is not None else None

    if colors is None:
        import matplotlib.colors as mcolors
        palette = list(mcolors.TABLEAU_COLORS.values())
        # One color per distinct phi (in column order), reused across both
        # fully_connected values; the solid/dashed style tells them apart.
        # The modulo wraps around instead of raising IndexError when there
        # are more phi values than palette entries.
        phis = list(dict.fromkeys(col[1] for col in lines.columns))
        colors = [palette[phis.index(col[1]) % len(palette)]
                  for col in lines.columns]

    for i, col in enumerate(lines.columns.values):
        is_fully_connected, phi = col[0], col[1]
        style = '--' if is_fully_connected else '-'
        # Only the solid (non-fully-connected) lines are labelled, so each
        # phi appears once in the legend.
        lbl = f'$\\varphi$ = {phi}' if legend and not is_fully_connected else None

        if bands is not None:
            plot_line_with_error(ax,
                                 lines[col],
                                 bands[col],
                                 color=colors[i],
                                 label=lbl,
                                 style=style)
        else:
            plot_line(ax,
                      lines[col],
                      color=colors[i],
                      label=lbl,
                      style=style)
    if legend:
        ax.legend()
    ax.set_ylabel(observable)
    

def plot_line_with_error(ax, data, error, color, label, style='-'):
    """Plot ``data`` against its index with a shaded +/- ``error`` band.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : pandas.Series
        Central values, indexed by the x coordinate.
    error : pandas.Series
        Half-width of the band. Its values are used positionally, so it
        must share ``data``'s index order.
    color : matplotlib color
        Used for the line, the band and its edges.
    label : str or None
        Legend label for the central line (None means no legend entry).
    style : str
        Matplotlib format string for the central line, e.g. '-' or '--'.
    """
    t = data.index.values
    x = data.values
    dx = error.values

    # Apply ``style`` so the solid/dashed distinction between series
    # survives when error bands are drawn.
    ax.plot(t, x, style, color=color, label=label)
    ax.fill_between(t, x + dx, x - dx, color=color, alpha=.1)
    # Faint edge lines outline the band.
    ax.plot(t, x + dx, color=color, alpha=.2)
    ax.plot(t, x - dx, color=color, alpha=.2)
    
def plot_line(ax, data, color, label, style='-'):
    """Draw ``data`` against its index as a single line, without any error band.

    ``style`` is a matplotlib format string (e.g. '-' or '--'). ``color``
    and ``label`` are forwarded unchanged to ``ax.plot``.
    """
    ax.plot(data.index.values, data.values, style, color=color, label=label)

In [None]:
fig, axes = make_fig(nrows=1, figsize=(8,3))

observables = ['N_c over N', '[cd] over M']

rolling_mean_window = 1   # moving-average width in rows (tsteps); 1 = no smoothing
show_error = False        # True: shade a +/- std band taken from data_std
legend_panel = 1          # index of the panel that carries the phi legend
time_window = (300, 700)  # x-range shown in every panel


def rolling_mean_per_series(df, window):
    """Moving average over consecutive rows, computed separately for each series.

    A series is one combination of every index level except 'tstep'.
    Calling ``df.rolling`` on the stacked frame would instead average
    neighbouring rows that belong to *different* series (e.g. different phi
    at the same tstep) whenever ``window > 1``. This assumes rows within a
    series are in tstep order, as in the sorted output of
    ``groupby(...).mean()``. The result keeps ``df``'s index and row order.
    """
    series_levels = [name for name in df.index.names if name != 'tstep']
    return df.groupby(level=series_levels).transform(
        lambda group: group.rolling(window).mean())


smoothed_mean = rolling_mean_per_series(data_mean, rolling_mean_window)
smoothed_std = (rolling_mean_per_series(data_std, rolling_mean_window)
                if show_error else None)

for panel, (ax, observable) in enumerate(zip(axes, observables)):
    plot_observable(ax, observable,
                    smoothed_mean,
                    error=smoothed_std,
                    legend=(panel == legend_panel))
    ax.set_xlim(time_window)
    ax.set_xlabel('t')

fig.tight_layout()
fig.savefig('compare_fc_and_adaptive.pdf')