In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os

In [None]:
# Plotting Helpers

from __future__ import division, print_function, absolute_import

import numpy as np
from matplotlib.colors import ColorConverter
from matplotlib import rcParams


# Attribution for the plotting helpers defined in this cell.
__author__ = 'Felix Berkenkamp / Johannes Kirschner'
# Names exported by a star-import of these helpers.
# NOTE(review): plot_errorbars is defined below but is not listed here --
# confirm whether it should be exported as well.
__all__ = ['set_figure_params', 'emulate_color', 'linewidth_in_data_units',
           'adapt_figure_size_from_axes', 'cm2inches', 'hide_all_ticks',
           'hide_spines', 'set_frame_properties']


def emulate_color(color, alpha=1, background_color=(1, 1, 1)):
    """Blend a color onto an opaque background and return plain RGB.

    Both colors are first converted to RGB triples with matplotlib's
    ``ColorConverter``. Each channel is then mixed as
    ``(1 - alpha) * background + alpha * color``.

    Parameters
    ----------
    color: matplotlib color specification
        The foreground color.
    alpha: float, optional
        Mixing weight of the foreground color. Defaults to 1.
    background_color: matplotlib color specification, optional
        The color to blend against. Defaults to ``(1, 1, 1)``.

    Returns
    -------
    blended: list
        The three blended channel values.
    """
    converter = ColorConverter()
    foreground = converter.to_rgb(color)
    background = converter.to_rgb(background_color)

    blended = []
    for fg_channel, bg_channel in zip(foreground, background):
        blended.append((1 - alpha) * bg_channel + alpha * fg_channel)
    return blended


def cm2inches(centimeters):
    """Return the given length in inches.

    Parameters
    ----------
    centimeters: number
        Length in centimeters.
    """
    cm_per_inch = 2.54
    return centimeters / cm_per_inch


def set_figure_params(serif=True):
    """Install the default font, font-size and LaTeX settings in rcParams.

    Parameters
    ----------
    serif: bool, optional
        If True (default) the font family is 'serif', otherwise
        'sans-serif'.
    """
    serif_fonts = ['Times',
                   'Palatino',
                   'New Century Schoolbook',
                   'Bookman',
                   'Computer Modern Roman']
    sans_serif_fonts = ['Times',
                        'Helvetica',
                        'Avant Garde',
                        'Computer Modern Sans serif']

    # Keep \mathcal in the Computer Modern symbol font instead of Times.
    latex_preamble = r'\DeclareMathAlphabet{\mathcal}{OMS}{cmsy}{m}{n}'

    settings = {
        # fonts and LaTeX rendering
        'font.family': 'serif' if serif else 'sans-serif',
        'font.serif': serif_fonts,
        'font.sans-serif': sans_serif_fonts,
        'text.usetex': True,
        'text.latex.preamble': latex_preamble,

        # axes
        'axes.labelsize': 9,
        'axes.linewidth': .75,

        # text sizes
        'font.size': 8,
        'legend.fontsize': 8,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,

        # Deliberately left at matplotlib's defaults:
        # 'figure.dpi', 'savefig.dpi', 'legend.numpoints'
    }

    rcParams.update(settings)


def hide_all_ticks(axis):
    """Hide all ticks and the bottom/left tick labels on the axis.

    Parameters
    ----------
    axis: matplotlib axis
        The axis whose ``tick_params`` method is called.
    """
    # Pass real booleans: tick_params documents these arguments as bool, and
    # the string 'off' is a non-empty (truthy) value, so it would leave the
    # ticks switched on in Matplotlib versions that do not translate it.
    axis.tick_params(axis='both',        # changes apply to both x- and y-axis
                     which='both',       # affect both major and minor ticks
                     bottom=False,       # ticks along the bottom edge are off
                     top=False,          # ticks along the top edge are off
                     left=False,         # no ticks left
                     right=False,        # no ticks right
                     labelbottom=False,  # no tick-label at bottom
                     labelleft=False)    # no tick-label at left


def hide_spines(*axes, top=True, right=True):
    """Hide the top and/or right spine of every given axis.

    Parameters
    ----------
    *axes: matplotlib axes
        Any number of axes to modify; with no axes this is a no-op.
    top: bool, optional
        If True (default), hide the top spine and put the x ticks at the
        bottom only.
    right: bool, optional
        If True (default), hide the right spine and put the y ticks on the
        left only.
    """
    for axis in axes:
        if top:
            axis.spines['top'].set_visible(False)
            axis.xaxis.set_ticks_position('bottom')
        if right:
            axis.spines['right'].set_visible(False)
            axis.yaxis.set_ticks_position('left')


def set_frame_properties(axis, color, lw):
    """Apply one linewidth and one color to every spine of the axis.

    Parameters
    ----------
    axis: matplotlib axis
        The axis whose ``spines`` are modified.
    color: matplotlib color specification
        Passed to ``set_color`` of each spine.
    lw: float
        Passed to ``set_linewidth`` of each spine.
    """
    borders = list(axis.spines.values())
    for border in borders:
        border.set_linewidth(lw)
        border.set_color(color)


def linewidth_in_data_units(linewidth, axis, reference='y'):
    """
    Convert a linewidth in data units to linewidth in points.

    Parameters
    ----------
    linewidth: float
        Linewidth in data units of the respective reference-axis
    axis: matplotlib axis
        The axis which is used to extract the relevant transformation
        data (data limits and size must not change afterwards)
    reference: string
        The axis that is taken as a reference for the data width.
        Possible values: 'x' and 'y'. Defaults to 'y'.

    Returns
    -------
    linewidth: numpy.ndarray
        Linewidth in points. Because the data range is computed with
        ``np.diff`` on the two axis limits, the result is an array holding
        a single value.

    Raises
    ------
    ValueError
        If ``reference`` is neither 'x' nor 'y'.
    """
    # Reject unknown references up front instead of failing later with an
    # unbound local variable.
    if reference not in ('x', 'y'):
        raise ValueError(
            "reference must be 'x' or 'y', got {!r}".format(reference))

    fig = axis.get_figure()

    if reference == 'x':
        # width of the axis in inches
        axis_length = fig.get_figwidth() * axis.get_position().width
        value_range = np.diff(axis.get_xlim())
    else:
        # height of the axis in inches
        axis_length = fig.get_figheight() * axis.get_position().height
        value_range = np.diff(axis.get_ylim())

    # Convert axis_length from inches to points
    axis_length *= 72

    return (linewidth / value_range) * axis_length


def adapt_figure_size_from_axes(axes):
    """
    Resize the parent figures so that all given axes get the same size.

    Figures that are placed side by side in LaTeX often differ in how much
    room their labels take, so equal figure sizes give unequal axes. This
    helper resizes the figures after plotting: every figure keeps the
    white space it currently spends outside its axis and gets the average
    axis size (in inches) added on top.

    Call plt.tight_layout() again after this operation!

    It does not handle several axes on one figure that should scale
    proportionally, but that should be an easy extension.

    Parameters
    ----------
    axes: list
        List of axes that we want to have the same size (need to be
        on different figures)
    """
    figures = []
    relative_sizes = []
    inch_sizes = []
    for ax in axes:
        parent = ax.get_figure()
        figures.append(parent)
        # axis size as a fraction of the figure, figure size in inches
        relative_sizes.append(ax.get_position().size)
        inch_sizes.append(parent.get_size_inches())

    relative_sizes = np.array(relative_sizes)
    inch_sizes = np.array(inch_sizes)

    # Mean axis size in inches over all axes
    mean_axis_inches = np.mean(relative_sizes * inch_sizes, axis=0)

    # White space that is not being used by the axis so far (e.g., labels)
    whitespace_inches = (1 - relative_sizes) * inch_sizes

    for parent, target in zip(figures, whitespace_inches + mean_axis_inches):
        parent.set_size_inches(target)
        
def plot_errorbars(axis, x, y, err, num=10, sigma=2, **kwargs):
    """Draw a sparse set of error bars (without markers) on the axis.

    The first ``len(y) // (2 * num)`` points are skipped and an error bar
    is then drawn on every ``len(y) // num``-th remaining point.

    Parameters
    ----------
    axis: matplotlib axis
        The axis whose ``errorbar`` method is called.
    x, y: sliceable sequences
        Data coordinates.
    err: array
        Error per data point; it is multiplied by ``sigma``.
    num: int, optional
        Controls the spacing and the offset of the bars (see above).
        Defaults to 10.
    sigma: float, optional
        Scaling factor applied to ``err``. Defaults to 2.
    **kwargs
        Forwarded to ``axis.errorbar``.
    """
    T = len(y)
    offset = T // (2 * num)
    # errorevery is used as a step and must be positive; T // num is 0 when
    # there are fewer than `num` points, so fall back to drawing every bar.
    every = max(1, T // num)
    axis.errorbar(x[offset:], y[offset:], sigma * err[offset:],
                  ecolor=emulate_color('black', alpha=0.6),
                  capsize=2,
                  capthick=0.3,
                  elinewidth=0.3,
                  linewidth=0.6,
                  fmt="none",
                  errorevery=every,
                  **kwargs)

In [None]:
# Tableau-20 palette, first written as 8-bit RGB triples.
tableau20 = [
    (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
]

# Rescale every channel from 0..255 to 0..1 floats for use as plot colors.
tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in tableau20]

# Preview the palette: one unit-height bar per color, labelled by its index.
plt.figure(figsize=(10, 1))
for position, swatch in enumerate(tableau20):
    plt.bar(position, 1, color=swatch)

plt.xticks(list(range(20)))
plt.show()

In [None]:
# Directories (relative to the current working directory) that hold the
# experiment runs and receive the generated plots.
basedir = os.path.join(os.getcwd(), '../runs2/')
plotdir = os.path.join(os.getcwd(), '../plots/')

# Text width taken from the LaTeX paper (unit presumably cm -- TODO confirm).
TEXTWIDTH = 15.23782

# Legend label per strategy name.
LABELS = {
    'ids-full': 'IDS',
    'ucb': 'UCB',
    'ucb-s2': 'UCB-log(s2)',
    'ids-directed2': 'IDS-direct',
    'asymptotic_ids': 'Asymptotic IDS',
    'asymptotic_ids2': 'Asymptotic IDS',
}

# Line color per strategy name, given as an index into the tableau20 palette.
# NOTE(review): 'ids-info_game' has a color but no entry in LABELS.
COLORS = {
    strategy: tableau20[palette_index]
    for strategy, palette_index in (
        ('ids-full', 0),
        ('ucb', 2),
        ('ucb-s2', 3),
        ('ids-directed2', 4),
        ('ids-info_game', 6),
        ('asymptotic_ids', 8),
        ('asymptotic_ids2', 9),
    )
}

In [None]:
def load_regret_data(env_path):
    """Load the aggregated regret file of every strategy of one environment.

    Every sub-directory of ``basedir/env_path`` is treated as one strategy.
    Inside it, files named ``aggr-regret-<count>.csv`` are looked up; the
    one with the largest ``<count>`` is read with ``np.loadtxt``.
    Directories without such a file are skipped.

    Parameters
    ----------
    env_path: str
        Directory of the environment, relative to ``basedir``.

    Returns
    -------
    data: dict
        Maps the strategy directory name to a dict with the keys
        'regret' (the loaded array) and 'repetitions' (the ``<count>``
        taken from the file name).
    """
    results = {}
    strategy_dirs = glob.iglob(os.path.join(basedir, env_path, '*/'), recursive=False)

    for path in strategy_dirs:
        # The pattern ends with a separator, so the directory name is the
        # second-to-last piece of the path.
        strategy = path.rsplit(os.sep, maxsplit=2)[1]

        # There might be several aggregation files, one per merge.
        candidates = [os.path.basename(f) for f in glob.iglob(path + 'aggr-regret-*.csv')]
        if not candidates:
            continue

        # Split off the merge count: drop '.csv', keep what follows the last '-'.
        run_counts = [int(name[:-4].rsplit('-', maxsplit=1)[1]) for name in candidates]

        # Use the aggregation that is based on the most runs.
        aggr_regret_file = candidates[np.argmax(run_counts)]
        aggr_count = np.max(run_counts)

        print(f"Reading {path[len(basedir):]}{aggr_regret_file} with {aggr_count} runs.")

        results[strategy] = {
            'regret': np.loadtxt(os.path.join(path, aggr_regret_file)),
            'repetitions': aggr_count,
        }

    return results

In [None]:
def plot_regret(data, axis, strategies=None, n=None):
    """Plot the regret curve of each strategy, with error bars, on one axis.

    Column 0 of each strategy's 'regret' array is drawn as a line. Column 1,
    divided by the square root of 'repetitions', is drawn as error bars
    scaled by ``sigma=2`` (presumably mean and standard deviation over the
    runs -- confirm against the aggregation script).

    Parameters
    ----------
    data: dict
        Maps strategy name to a dict with the keys 'regret' and
        'repetitions', as returned by ``load_regret_data``.
    axis: matplotlib axis
        The axis to draw on.
    strategies: iterable of str, optional
        Strategies to plot. Defaults to all keys of ``data``.
    n: int, optional
        Number of leading steps to plot. Defaults to the full length of
        each strategy's data; larger values are clipped to that length.
    """
    if strategies is None:
        strategies = data.keys()

    for strategy in strategies:
        strategy_data = data[strategy]
        regret = strategy_data['regret']
        rep = strategy_data['repetitions']

        # Resolve the horizon per strategy. Overwriting `n` here would make
        # the first strategy's length apply to all later ones, and a horizon
        # beyond the data would make x longer than the plotted values.
        horizon = len(regret) if n is None else min(n, len(regret))
        x = np.arange(horizon)
        values = regret[:horizon, 0]
        errors = regret[:horizon, 1] / np.sqrt(rep)

        # Fall back to the raw strategy name / the default color cycle for
        # strategies that have no entry in LABELS or COLORS.
        axis.plot(x, values,
                  label=LABELS.get(strategy, strategy),
                  color=COLORS.get(strategy),
                  linewidth=1)
        plot_errorbars(axis, x, values, errors, num=10, sigma=2)

In [None]:
# Load the aggregated regret data of every strategy found in the
# 'simple_bandit-100003' directory below basedir.
simple_bandit = load_regret_data('simple_bandit-100003')

In [None]:
# Notebook cell output: show which strategies were loaded.
simple_bandit.keys()

In [None]:
# Three panels showing the same data: full horizon, the first 10000 steps,
# and the full horizon again on a logarithmic x-axis.
fig, axis = plt.subplots(ncols=3, figsize=(12, 4))

for panel, horizon, log_x in zip(axis, (None, 10000, None), (False, False, True)):
    plot_regret(simple_bandit, panel, n=horizon)
    panel.legend()
    if log_x:
        panel.set_xscale('log')

fig.suptitle('Results from a bandit with 6 arms and large gap (0.9),\n showing same data on different horizon')
fig.savefig(os.path.join(plotdir, 'simple_bandit_3.pdf'))
